WIP multidevice training

This commit is contained in:
Jaret Burkett
2024-08-29 16:04:20 -06:00
parent a48c9aba8d
commit 4fa8fac5fd
3 changed files with 58 additions and 2 deletions

View File

@@ -5,6 +5,10 @@ from typing import Union, List
import numpy as np
from diffusers import T2IAdapter, ControlNetModel
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.data_loader import get_dataloader_datasets
@@ -63,7 +67,7 @@ class TrainerV2(BaseSDTrainProcess):
self.scaler._unscale_grads_ = _unscale_grads_replacer
self.unified_training_model: UnifiedTrainingModel = None
self.device_ids = list(range(torch.cuda.device_count()))
def before_model_load(self):
    """No-op hook: this trainer does nothing at this point.

    NOTE(review): presumably invoked by the base train process before the
    model is loaded -- confirm against the parent class.
    """
    return None
@@ -137,7 +141,14 @@ class TrainerV2(BaseSDTrainProcess):
embedding=self.embedding,
timer=self.timer,
trigger_word=self.trigger_word,
gpu_ids=self.device_ids,
)
self.unified_training_model = nn.DataParallel(
self.unified_training_model,
device_ids=self.device_ids
)
self.unified_training_model = self.unified_training_model.to(self.device_torch)
# call parent hook
super().hook_before_train_loop()

View File

@@ -18,7 +18,7 @@ from toolkit.network_mixins import Network
from toolkit.prompt_utils import PromptEmbeds
from toolkit.reference_adapter import ReferenceAdapter
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.timer import Timer
from toolkit.timer import Timer, DummyTimer
from toolkit.train_tools import get_torch_dtype, apply_learnable_snr_gos, apply_snr_weight
from toolkit.config_modules import TrainConfig, AdapterConfig
@@ -40,6 +40,7 @@ class UnifiedTrainingModel(nn.Module):
embedding: Optional[Embedding] = None,
timer: Timer = None,
trigger_word: Optional[str] = None,
gpu_ids: Optional[Union[int, list]] = None,
):
super(UnifiedTrainingModel, self).__init__()
self.sd: StableDiffusion = sd
@@ -52,6 +53,8 @@ class UnifiedTrainingModel(nn.Module):
self.timer: Timer = timer
self.trigger_word: Optional[str] = trigger_word
self.device_torch = torch.device("cuda")
self.gpu_ids = gpu_ids
self.primary_gpu_id = self.gpu_ids[0] # The first in the list is primary
# misc config
self.do_long_prompts = False
@@ -61,6 +64,16 @@ class UnifiedTrainingModel(nn.Module):
if self.train_config.do_prior_divergence:
self.do_prior_prediction = True
# register modules from sd
self.text_encoders = nn.ModuleList([self.sd.text_encoder] if not isinstance(self.sd.text_encoder, list) else self.sd.text_encoder)
self.unet = self.sd.unet
self.vae = self.sd.vae
def is_primary_gpu(self):
    """Return True when the current CUDA device is the primary GPU.

    The primary GPU is ``self.primary_gpu_id``; the result is a comparison
    against ``torch.cuda.current_device()`` for the calling thread.
    """
    active_device = torch.cuda.current_device()
    return active_device == self.primary_gpu_id
def before_unet_predict(self):
    """No-op hook: the default implementation does nothing.

    NOTE(review): presumably called just before the UNet prediction step so
    subclasses can override it -- confirm against the call site.
    """
    return None
@@ -808,6 +821,12 @@ class UnifiedTrainingModel(nn.Module):
def forward(self, batch: DataLoaderBatchDTO):
if not self.is_primary_gpu():
# replace timer with dummy one
self.timer = DummyTimer()
self.device_torch = torch.cuda.current_device()
self.timer.start('preprocess_batch')
batch = self.preprocess_batch(batch)
dtype = get_torch_dtype(self.train_config.dtype)

View File

@@ -63,3 +63,29 @@ class Timer:
else:
# There was an exception, cancel the timer
self.cancel(self.current_timer)
class DummyTimer:
    """No-op stand-in for ``Timer``.

    Every method accepts the same arguments as the visible ``Timer`` usage
    and does nothing, so timing calls can stay in place on code paths where
    measurements are not wanted (``UnifiedTrainingModel.forward`` swaps this
    in when it is not running on the primary GPU).
    """

    def __init__(self, name='Timer'):
        # Stored only so the object exposes the same ``name`` attribute.
        self.name = name

    def start(self, timer_name):
        """Ignore a request to start ``timer_name``."""
        pass

    def stop(self, timer_name):
        """Ignore a request to stop ``timer_name``."""
        pass

    def cancel(self, timer_name):
        """Ignore a request to cancel ``timer_name``.

        ``Timer`` exposes ``cancel`` (it calls it from ``__exit__`` on error);
        without it here, code that cancels a timer would raise
        ``AttributeError`` once the real timer is replaced by this dummy.
        """
        pass

    def print(self):
        """Print nothing; there are no measurements to report."""
        pass

    def reset(self):
        """Do nothing; there is no state to clear."""
        pass

    def __call__(self, timer_name):
        """Return ``self`` so ``with timer('name'):`` works as a context manager."""
        return self

    def __enter__(self):
        """Enter the context without starting anything; returns ``None``."""
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the context; returns ``None`` so exceptions still propagate."""
        pass