Mirror of https://github.com/ostris/ai-toolkit.git
WIP multidevice training
@@ -5,6 +5,10 @@ from typing import Union, List
 import numpy as np
 from diffusers import T2IAdapter, ControlNetModel
+import torch.distributed as dist
 from torch import nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
 
 from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.data_loader import get_dataloader_datasets
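
The three added imports are the standard PyTorch pieces for multi-process data parallelism: torch.distributed for process-group setup, DDP for gradient-synchronized model replicas, and DistributedSampler for sharding the dataset across ranks. They are not used yet in this commit; as a rough sketch of how they typically fit together (the torchrun-style launcher env vars and all names below are assumptions, not part of this commit):

# Hypothetical DDP wiring sketch, assuming a torchrun-style launcher that
# sets RANK, LOCAL_RANK and WORLD_SIZE; `model` and `dataset` are placeholders.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def setup_ddp(model, dataset, batch_size=4):
    dist.init_process_group(backend="nccl")        # one process per GPU
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    ddp_model = DDP(model.to(local_rank), device_ids=[local_rank])
    # DistributedSampler gives each rank a disjoint shard of the dataset
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return ddp_model, loader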
@@ -63,7 +67,7 @@ class TrainerV2(BaseSDTrainProcess):
         self.scaler._unscale_grads_ = _unscale_grads_replacer
 
         self.unified_training_model: UnifiedTrainingModel = None
 
+        self.device_ids = list(range(torch.cuda.device_count()))
 
     def before_model_load(self):
         pass
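
self.device_ids simply enumerates every CUDA device visible to the process, which is the index list both nn.DataParallel and DDP expect. A small illustration (output values are examples, not from the commit):

# What self.device_ids evaluates to:
import torch
device_ids = list(range(torch.cuda.device_count()))
print(device_ids)  # e.g. [0, 1, 2, 3] on a 4-GPU box; [] on a CPU-only machine
# Note: CUDA_VISIBLE_DEVICES=1,3 would remap the visible devices to [0, 1]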
@@ -137,7 +141,14 @@ class TrainerV2(BaseSDTrainProcess):
             embedding=self.embedding,
             timer=self.timer,
             trigger_word=self.trigger_word,
+            gpu_ids=self.device_ids,
         )
+        self.unified_training_model = nn.DataParallel(
+            self.unified_training_model,
+            device_ids=self.device_ids
+        )
+
+        self.unified_training_model = self.unified_training_model.to(self.device_torch)
 
         # call parent hook
         super().hook_before_train_loop()
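
This hunk is the core of the WIP change: the unified training model is wrapped in nn.DataParallel, so a single forward call scatters the input batch along dim 0 across self.device_ids, runs a replica on each GPU, and gathers outputs back on the primary device; the subsequent .to(self.device_torch) keeps the master parameters on that primary device, as DataParallel requires. A minimal sketch of that pattern (the toy module and loss are hypothetical stand-ins, not this repo's UnifiedTrainingModel):

# Sketch of the nn.DataParallel pattern used above, on a 2-GPU machine;
# the Sequential here is a toy stand-in for UnifiedTrainingModel.
import torch
from torch import nn

net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 1),
)
model = nn.DataParallel(net, device_ids=[0, 1]).to("cuda:0")

batch = torch.randn(8, 3, 64, 64, device="cuda:0")  # dim 0 is split across GPUs
loss = model(batch).mean()  # replicate module -> scatter batch -> gather outputs
loss.backward()             # gradients accumulate on the cuda:0 parameters

Worth noting: the new imports bring in DDP and DistributedSampler, yet the commit still wraps with single-process nn.DataParallel. That mismatch fits the WIP label, since DataParallel is the quicker mechanism to bolt on while a proper per-process DDP launch path is built out.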