diff --git a/extensions_built_in/sd_trainer/TrainerV2.py b/extensions_built_in/sd_trainer/TrainerV2.py
index 8b4a1ebe..7bda9cd1 100644
--- a/extensions_built_in/sd_trainer/TrainerV2.py
+++ b/extensions_built_in/sd_trainer/TrainerV2.py
@@ -5,6 +5,10 @@ from typing import Union, List
 import numpy as np
 from diffusers import T2IAdapter, ControlNetModel
+import torch.distributed as dist
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler

 from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.data_loader import get_dataloader_datasets
@@ -63,7 +67,7 @@ class TrainerV2(BaseSDTrainProcess):
         self.scaler._unscale_grads_ = _unscale_grads_replacer

         self.unified_training_model: UnifiedTrainingModel = None
-
+        self.device_ids = list(range(torch.cuda.device_count()))

     def before_model_load(self):
         pass
@@ -137,7 +141,14 @@ class TrainerV2(BaseSDTrainProcess):
             embedding=self.embedding,
             timer=self.timer,
             trigger_word=self.trigger_word,
+            gpu_ids=self.device_ids,
         )
+        self.unified_training_model = nn.DataParallel(
+            self.unified_training_model,
+            device_ids=self.device_ids
+        )
+
+        self.unified_training_model = self.unified_training_model.to(self.device_torch)

         # call parent hook
         super().hook_before_train_loop()
diff --git a/toolkit/models/unified_training_model.py b/toolkit/models/unified_training_model.py
index aa6fdde1..ea9d4d2f 100644
--- a/toolkit/models/unified_training_model.py
+++ b/toolkit/models/unified_training_model.py
@@ -18,7 +18,7 @@ from toolkit.network_mixins import Network
 from toolkit.prompt_utils import PromptEmbeds
 from toolkit.reference_adapter import ReferenceAdapter
 from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
-from toolkit.timer import Timer
+from toolkit.timer import Timer, DummyTimer
 from toolkit.train_tools import get_torch_dtype, apply_learnable_snr_gos, apply_snr_weight
 from toolkit.config_modules import TrainConfig, AdapterConfig
@@ -40,6 +40,7 @@ class UnifiedTrainingModel(nn.Module):
             embedding: Optional[Embedding] = None,
             timer: Timer = None,
             trigger_word: Optional[str] = None,
+            gpu_ids: Optional[Union[int, list]] = None,
     ):
         super(UnifiedTrainingModel, self).__init__()
         self.sd: StableDiffusion = sd
@@ -52,6 +53,8 @@ class UnifiedTrainingModel(nn.Module):
         self.timer: Timer = timer
         self.trigger_word: Optional[str] = trigger_word
         self.device_torch = torch.device("cuda")
+        self.gpu_ids = gpu_ids
+        self.primary_gpu_id = self.gpu_ids[0]  # The first in the list is primary

         # misc config
         self.do_long_prompts = False
@@ -61,6 +64,16 @@ class UnifiedTrainingModel(nn.Module):
         if self.train_config.do_prior_divergence:
             self.do_prior_prediction = True
+
+        # register modules from sd
+        self.text_encoders = nn.ModuleList([self.sd.text_encoder] if not isinstance(self.sd.text_encoder, list) else self.sd.text_encoder)
+        self.unet = self.sd.unet
+        self.vae = self.sd.vae
+
+    def is_primary_gpu(self):
+        return torch.cuda.current_device() == self.primary_gpu_id
+
+
     def before_unet_predict(self):
         pass
@@ -808,6 +821,12 @@ class UnifiedTrainingModel(nn.Module):

     def forward(self, batch: DataLoaderBatchDTO):
+        if not self.is_primary_gpu():
+            # replace timer with dummy one
+            self.timer = DummyTimer()
+
+        self.device_torch = torch.cuda.current_device()
+
         self.timer.start('preprocess_batch')
         batch = self.preprocess_batch(batch)
         dtype = get_torch_dtype(self.train_config.dtype)
diff --git a/toolkit/timer.py b/toolkit/timer.py
index ca4fecba..60783f1b 100644
--- a/toolkit/timer.py
+++ b/toolkit/timer.py
@@ -63,3 +63,29 @@ class Timer:
         else:
             # There was an exception, cancel the timer
             self.cancel(self.current_timer)
+
+
+class DummyTimer:
+    def __init__(self, name='Timer'):
+        self.name = name
+
+    def start(self, timer_name):
+        pass
+
+    def stop(self, timer_name):
+        pass
+
+    def print(self):
+        pass
+
+    def reset(self):
+        pass
+
+    def __call__(self, timer_name):
+        return self
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        pass
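
For reference, the change above relies on the standard torch.nn.DataParallel pattern: the loss is computed inside forward(), so each replica evaluates its shard of the batch on its own GPU and the gathered results are reduced on the primary device. The sketch below is a minimal, self-contained illustration of that pattern only, not code from this repository; ToyTrainingModel and its tensors are hypothetical stand-ins for UnifiedTrainingModel and its batch.

import torch
from torch import nn


class ToyTrainingModel(nn.Module):
    """Hypothetical stand-in: the loss lives inside forward(), as in UnifiedTrainingModel."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(16, 1)

    def forward(self, x):
        pred = self.net(x)
        # return a 1-element tensor so DataParallel can gather one loss per replica
        return ((pred - 1.0) ** 2).mean().unsqueeze(0)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_ids = list(range(torch.cuda.device_count()))

model = ToyTrainingModel()
if len(device_ids) > 1:
    # same wrapping as in hook_before_train_loop(): replicate the module, scatter tensor inputs along dim 0
    model = nn.DataParallel(model, device_ids=device_ids)
model = model.to(device)

x = torch.randn(8, 16, device=device)  # batch dimension 0 is what DataParallel splits
loss = model(x).mean()                 # average the per-replica losses on the primary GPU
loss.backward()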