mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
WIP multidevice training
This commit is contained in:
@@ -5,6 +5,10 @@ from typing import Union, List
|
||||
|
||||
import numpy as np
|
||||
from diffusers import T2IAdapter, ControlNetModel
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from toolkit.clip_vision_adapter import ClipVisionAdapter
|
||||
from toolkit.data_loader import get_dataloader_datasets
|
||||
@@ -63,7 +67,7 @@ class TrainerV2(BaseSDTrainProcess):
|
||||
self.scaler._unscale_grads_ = _unscale_grads_replacer
|
||||
|
||||
self.unified_training_model: UnifiedTrainingModel = None
|
||||
|
||||
self.device_ids = list(range(torch.cuda.device_count()))
|
||||
|
||||
def before_model_load(self):
|
||||
pass
|
||||
@@ -137,7 +141,14 @@ class TrainerV2(BaseSDTrainProcess):
|
||||
embedding=self.embedding,
|
||||
timer=self.timer,
|
||||
trigger_word=self.trigger_word,
|
||||
gpu_ids=self.device_ids,
|
||||
)
|
||||
self.unified_training_model = nn.DataParallel(
|
||||
self.unified_training_model,
|
||||
device_ids=self.device_ids
|
||||
)
|
||||
|
||||
self.unified_training_model = self.unified_training_model.to(self.device_torch)
|
||||
|
||||
# call parent hook
|
||||
super().hook_before_train_loop()
|
||||
|
||||
@@ -18,7 +18,7 @@ from toolkit.network_mixins import Network
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from toolkit.reference_adapter import ReferenceAdapter
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||
from toolkit.timer import Timer
|
||||
from toolkit.timer import Timer, DummyTimer
|
||||
from toolkit.train_tools import get_torch_dtype, apply_learnable_snr_gos, apply_snr_weight
|
||||
|
||||
from toolkit.config_modules import TrainConfig, AdapterConfig
|
||||
@@ -40,6 +40,7 @@ class UnifiedTrainingModel(nn.Module):
|
||||
embedding: Optional[Embedding] = None,
|
||||
timer: Timer = None,
|
||||
trigger_word: Optional[str] = None,
|
||||
gpu_ids: Optional[Union[int, list]] = None,
|
||||
):
|
||||
super(UnifiedTrainingModel, self).__init__()
|
||||
self.sd: StableDiffusion = sd
|
||||
@@ -52,6 +53,8 @@ class UnifiedTrainingModel(nn.Module):
|
||||
self.timer: Timer = timer
|
||||
self.trigger_word: Optional[str] = trigger_word
|
||||
self.device_torch = torch.device("cuda")
|
||||
self.gpu_ids = gpu_ids
|
||||
self.primary_gpu_id = self.gpu_ids[0] # The first in the list is primary
|
||||
|
||||
# misc config
|
||||
self.do_long_prompts = False
|
||||
@@ -61,6 +64,16 @@ class UnifiedTrainingModel(nn.Module):
|
||||
if self.train_config.do_prior_divergence:
|
||||
self.do_prior_prediction = True
|
||||
|
||||
|
||||
# register modules from sd
|
||||
self.text_encoders = nn.ModuleList([self.sd.text_encoder] if not isinstance(self.sd.text_encoder, list) else self.sd.text_encoder)
|
||||
self.unet = self.sd.unet
|
||||
self.vae = self.sd.vae
|
||||
|
||||
def is_primary_gpu(self):
    """Report whether the currently active CUDA device is the primary GPU.

    Compares the index returned by ``torch.cuda.current_device()`` with
    ``self.primary_gpu_id``.
    """
    active_device = torch.cuda.current_device()
    return active_device == self.primary_gpu_id
|
||||
|
||||
|
||||
def before_unet_predict(self):
|
||||
pass
|
||||
|
||||
@@ -808,6 +821,12 @@ class UnifiedTrainingModel(nn.Module):
|
||||
|
||||
|
||||
def forward(self, batch: DataLoaderBatchDTO):
|
||||
if not self.is_primary_gpu():
|
||||
# replace timer with dummy one
|
||||
self.timer = DummyTimer()
|
||||
|
||||
self.device_torch = torch.cuda.current_device()
|
||||
|
||||
self.timer.start('preprocess_batch')
|
||||
batch = self.preprocess_batch(batch)
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
@@ -63,3 +63,29 @@ class Timer:
|
||||
else:
|
||||
# There was an exception, cancel the timer
|
||||
self.cancel(self.current_timer)
|
||||
|
||||
|
||||
class DummyTimer:
    """No-op timer: every operation is accepted and silently ignored.

    Useful where timing calls must remain in the code path but no
    measurements or output are wanted. Supports both the explicit
    ``start``/``stop`` style and the ``with timer('name'):`` style.
    """

    def __init__(self, name='Timer'):
        # The name is stored only so the attribute exists; it is never used here.
        self.name = name

    def start(self, timer_name):
        """Ignore a request to start ``timer_name``."""

    def stop(self, timer_name):
        """Ignore a request to stop ``timer_name``."""

    def cancel(self, timer_name):
        """Ignore a request to cancel ``timer_name``."""

    def print(self):
        """Print nothing; there are no measurements to report."""

    def reset(self):
        """Do nothing; there is no state to clear."""

    def __call__(self, timer_name):
        """Return this instance so ``with timer('name'):`` works."""
        return self

    def __enter__(self):
        # Return self so ``with timer('name') as t:`` binds the timer, not None.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Do nothing on exit; exceptions from the ``with`` body propagate."""
|
||||
|
||||
Reference in New Issue
Block a user