WIP multidevice training

Jaret Burkett
2024-08-29 16:04:20 -06:00
parent a48c9aba8d
commit 4fa8fac5fd
3 changed files with 58 additions and 2 deletions


@@ -5,6 +5,10 @@ from typing import Union, List
 import numpy as np
 from diffusers import T2IAdapter, ControlNetModel
+import torch.distributed as dist
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
 from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.data_loader import get_dataloader_datasets
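For orientation, here is a minimal sketch (not part of this commit) of how the four newly imported pieces — torch.distributed, nn, DistributedDataParallel, and DistributedSampler — typically fit together in a script launched with torchrun. The model and dataset below are hypothetical placeholders, not names from this repository.

```python
# Minimal multi-process training sketch, assuming launch via torchrun,
# which sets RANK, LOCAL_RANK, and WORLD_SIZE in the environment.
import os
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = nn.Linear(16, 1).cuda(local_rank)      # placeholder model
    model = DDP(model, device_ids=[local_rank])

    dataset = TensorDataset(torch.randn(128, 16))  # placeholder data
    sampler = DistributedSampler(dataset)          # shards data per rank
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)

    for (batch,) in loader:
        out = model(batch.cuda(local_rank))
        out.mean().backward()                      # gradients sync here

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Each rank trains on its own shard of the data, and DDP synchronizes gradients across processes during backward().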
@@ -63,7 +67,7 @@ class TrainerV2(BaseSDTrainProcess):
         self.scaler._unscale_grads_ = _unscale_grads_replacer
         self.unified_training_model: UnifiedTrainingModel = None
+        self.device_ids = list(range(torch.cuda.device_count()))

     def before_model_load(self):
         pass
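The added line enumerates every CUDA device visible to the process. A small illustration of the assumed behavior, not repository code:

```python
# One integer id per visible CUDA device; honors CUDA_VISIBLE_DEVICES.
import torch

device_ids = list(range(torch.cuda.device_count()))
print(device_ids)  # e.g. [0, 1] on a two-GPU machine, [] without CUDA
```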
@@ -137,7 +141,14 @@ class TrainerV2(BaseSDTrainProcess):
             embedding=self.embedding,
             timer=self.timer,
             trigger_word=self.trigger_word,
+            gpu_ids=self.device_ids,
         )
+        self.unified_training_model = nn.DataParallel(
+            self.unified_training_model,
+            device_ids=self.device_ids
+        )
+        self.unified_training_model = self.unified_training_model.to(self.device_torch)

         # call parent hook
         super().hook_before_train_loop()
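Note that this hunk wraps the model in single-process nn.DataParallel even though DistributedDataParallel was imported above, consistent with the commit's WIP status. A self-contained sketch of the pattern applied here, with a toy module standing in for UnifiedTrainingModel:

```python
# Hedged sketch of the nn.DataParallel wrap used in hook_before_train_loop;
# nn.Linear is a hypothetical stand-in for UnifiedTrainingModel.
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_ids = list(range(torch.cuda.device_count()))

model = nn.Linear(16, 1)
if len(device_ids) > 1:
    # DataParallel replicates the module onto each GPU every forward pass
    # and splits the input batch along dim 0 across the replicas.
    model = nn.DataParallel(model, device_ids=device_ids)
model = model.to(device)

loss = model(torch.randn(8, 16, device=device)).mean()
loss.backward()
```

Unlike DDP, DataParallel runs in a single process and scatters inputs and gathers outputs through the first GPU on every step, which is simpler to bolt on but generally slower at scale.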