import os
import random
from collections import OrderedDict
from typing import Union, List

import numpy as np
from diffusers import T2IAdapter, ControlNetModel
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.data_loader import get_dataloader_datasets
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.stable_diffusion_model import BlankNetwork
from toolkit.train_tools import get_torch_dtype, add_all_snr_to_noise_scheduler
import gc
import torch
from jobs.process import BaseSDTrainProcess
from torchvision import transforms
from diffusers import EMAModel
import math
from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.unified_training_model import UnifiedTrainingModel


def flush():
    """Release cached CUDA memory and run the Python garbage collector."""
    torch.cuda.empty_cache()
    gc.collect()


adapter_transforms = transforms.Compose([
    transforms.ToTensor(),
])


class TrainerV2(BaseSDTrainProcess):
    """Training process that runs a ``UnifiedTrainingModel`` under ``nn.DataParallel``.

    The forward pass (which returns the loss) is delegated to the unified
    model; this class owns the backward pass, gradient scaling/clipping,
    optimizer/scheduler stepping and the post-step embedding restoration.
    """

    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        """Initialise trainer state, the AMP grad scaler and the device list.

        Args:
            process_id: Forwarded unchanged to ``BaseSDTrainProcess``.
            job: Forwarded unchanged to ``BaseSDTrainProcess``.
            config: Forwarded unchanged to ``BaseSDTrainProcess``.
            **kwargs: Forwarded unchanged to ``BaseSDTrainProcess``.
        """
        super().__init__(process_id, job, config, **kwargs)
        # Declared here for type checkers only; assigned in before_dataset_load().
        self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None]
        self.do_prior_prediction = False
        self.do_long_prompts = False
        self.do_guided_loss = False

        self._clip_image_embeds_unconditional: Union[List[str], None] = None
        self.negative_prompt_pool: Union[List[str], None] = None
        self.batch_negative_prompt: Union[List[str], None] = None

        self.scaler = torch.cuda.amp.GradScaler()

        self.is_bfloat = self.train_config.dtype in ("bfloat16", "bf16")

        # Grad scaling is on by default and disabled for full fine tuning
        # and when an adapter is being trained.
        self.do_grad_scale = True
        if self.is_fine_tuning:
            self.do_grad_scale = False
        if self.adapter_config is not None and self.adapter_config.train:
            self.do_grad_scale = False

        if self.train_config.dtype in ["fp16", "float16"]:
            # patch the scaler to allow fp16 training
            org_unscale_grads = self.scaler._unscale_grads_

            def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
                # Ignore the caller's allow_fp16 and always pass True.
                return org_unscale_grads(optimizer, inv_scale, found_inf, True)

            self.scaler._unscale_grads_ = _unscale_grads_replacer

        # Built (and wrapped in nn.DataParallel) in hook_before_train_loop().
        self.unified_training_model: UnifiedTrainingModel = None
        # One entry per CUDA device visible to this process.
        self.device_ids = list(range(torch.cuda.device_count()))

    def before_model_load(self):
        """Hook called before the model is loaded; intentionally a no-op."""
        pass

    def before_dataset_load(self):
        """Load the optional frozen assistant adapter and validate the config.

        Raises:
            ValueError: If ``adapter_assist_type`` is neither ``"t2i"`` nor
                ``"control_net"``, or if turbo outputs are requested.
        """
        self.assistant_adapter = None
        # get adapter assistant if one is set
        if self.train_config.adapter_assist_name_or_path is not None:
            adapter_path = self.train_config.adapter_assist_name_or_path
            assist_dtype = get_torch_dtype(self.train_config.dtype)

            if self.train_config.adapter_assist_type == "t2i":
                # dont name this adapter since we are not training it
                self.assistant_adapter = T2IAdapter.from_pretrained(
                    adapter_path, torch_dtype=assist_dtype
                ).to(self.device_torch)
            elif self.train_config.adapter_assist_type == "control_net":
                self.assistant_adapter = ControlNetModel.from_pretrained(
                    adapter_path, torch_dtype=assist_dtype
                ).to(self.device_torch, dtype=assist_dtype)
            else:
                raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}")

            # The assistant is never trained: eval mode, no gradients.
            self.assistant_adapter.eval()
            self.assistant_adapter.requires_grad_(False)
            flush()
        if self.train_config.train_turbo and self.train_config.show_turbo_outputs:
            raise ValueError("Turbo outputs are not supported on MultiGPUSDTrainer")

    def hook_before_train_loop(self):
        """Place modules on their devices and build the data-parallel model.

        Raises:
            ValueError: If regularisation datasets are used with cached clip
                vision embeds but no unconditional embeds are available, or
                if a negative prompt is configured.
        """
        # if self.train_config.do_prior_divergence:
        #     self.do_prior_prediction = True
        # move vae to device if we did not cache latents
        if not self.is_latents_cached:
            self.sd.vae.eval()
            self.sd.vae.to(self.device_torch)
        else:
            # offload it. Already cached
            self.sd.vae.to('cpu')
            flush()
        add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
        if self.adapter is not None:
            self.adapter.to(self.device_torch)

            # check if we have regs and using adapter and caching clip embeddings
            has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0
            is_caching_clip_embeddings = self.datasets is not None and any(
                [dataset.cache_clip_vision_to_disk for dataset in self.datasets]
            )

            if has_reg and is_caching_clip_embeddings:
                # we need a list of unconditional clip image embeds from other datasets to handle regs
                unconditional_clip_image_embeds = []
                for dataset in get_dataloader_datasets(self.data_loader):
                    unconditional_clip_image_embeds += dataset.clip_vision_unconditional_cache

                if not unconditional_clip_image_embeds:
                    raise ValueError("No unconditional clip image embeds found. This should not happen")

                self._clip_image_embeds_unconditional = unconditional_clip_image_embeds

        if self.train_config.negative_prompt is not None:
            raise ValueError("Negative prompt is not supported on MultiGPUSDTrainer")

        # setup the unified training model
        self.unified_training_model = UnifiedTrainingModel(
            sd=self.sd,
            network=self.network,
            adapter=self.adapter,
            assistant_adapter=self.assistant_adapter,
            train_config=self.train_config,
            adapter_config=self.adapter_config,
            embedding=self.embedding,
            timer=self.timer,
            trigger_word=self.trigger_word,
            gpu_ids=self.device_ids,
        )
        # From here on self.unified_training_model is the DataParallel wrapper;
        # the underlying UnifiedTrainingModel is reachable through `.module`.
        self.unified_training_model = nn.DataParallel(
            self.unified_training_model,
            device_ids=self.device_ids
        )

        self.unified_training_model = self.unified_training_model.to(self.device_torch)

        # call parent hook
        super().hook_before_train_loop()

    # you can expand these in a child class to make customization easier

    def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
        """Delegate batch preprocessing to the underlying unified model.

        ``nn.DataParallel`` does not forward custom methods of the module it
        wraps, so the wrapper is unwrapped before ``preprocess_batch`` is
        looked up.

        Args:
            batch: The batch to preprocess.

        Returns:
            Whatever ``UnifiedTrainingModel.preprocess_batch`` returns.
        """
        model = self.unified_training_model
        if isinstance(model, nn.DataParallel):
            model = model.module
        return model.preprocess_batch(batch)

    def before_unet_predict(self):
        """Customisation hook; intentionally a no-op."""
        pass

    def after_unet_predict(self):
        """Customisation hook; intentionally a no-op."""
        pass

    def end_of_training_loop(self):
        """Customisation hook run at the end of ``hook_train_loop``; no-op."""
        pass

    def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
        """Run one training step: forward, backward, and (if due) optimizer step.

        Args:
            batch: The batch passed to the unified training model.

        Returns:
            An ``OrderedDict`` with a single ``'loss'`` entry holding the
            step's loss as a Python float.
        """
        # NOTE(review): this clears gradients on every call, which looks like
        # it would discard gradients accumulated on previous accumulation
        # steps -- confirm against how is_grad_accumulation_step is driven.
        self.optimizer.zero_grad(set_to_none=True)

        loss = self.unified_training_model(batch)
        # With several devices nn.DataParallel gathers the per-replica scalar
        # losses into a 1-D tensor; reduce it so the NaN check, backward()
        # and .item() below operate on a single scalar.
        if loss.dim() > 0:
            loss = loss.mean()

        if torch.isnan(loss):
            print("loss is nan")
            loss = torch.zeros_like(loss).requires_grad_(True)

        network = self.network if self.network is not None else BlankNetwork()

        with network:
            with self.timer('backward'):
                # todo we have multiplier seperated. works for now as res are not in same batch, but need to change
                # IMPORTANT if gradient checkpointing do not leave with network when doing backward
                # it will destroy the gradients. This is because the network is a context manager
                # and will change the multipliers back to 0.0 when exiting. They will be
                # 0.0 for the backward pass and the gradients will be 0.0
                # I spent weeks on fighting this. DON'T DO IT
                # with fsdp_overlap_step_with_backward():
                # if self.is_bfloat:
                # loss.backward()
                # else:
                if not self.do_grad_scale:
                    loss.backward()
                else:
                    self.scaler.scale(loss).backward()

        if not self.is_grad_accumulation_step:
            # fix this for multi params
            if self.train_config.optimizer != 'adafactor':
                if self.do_grad_scale:
                    # Gradients must be unscaled before clipping.
                    self.scaler.unscale_(self.optimizer)
                if isinstance(self.params[0], dict):
                    # Param groups: clip each group's parameters separately.
                    for param_group in self.params:
                        torch.nn.utils.clip_grad_norm_(param_group['params'], self.train_config.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
            # only step if we are not accumulating
            with self.timer('optimizer_step'):
                # self.optimizer.step()
                if not self.do_grad_scale:
                    self.optimizer.step()
                else:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

                self.optimizer.zero_grad(set_to_none=True)
            if self.ema is not None:
                with self.timer('ema_update'):
                    self.ema.update()
        else:
            # gradient accumulation. Just a place for breakpoint
            pass

        # TODO Should we only step scheduler on grad step? If so, need to recalculate last step
        with self.timer('scheduler_step'):
            self.lr_scheduler.step()

        if self.embedding is not None:
            with self.timer('restore_embeddings'):
                # Let's make sure we don't update any embedding weights besides the newly added token
                self.embedding.restore_embeddings()
        if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
            with self.timer('restore_adapter'):
                # Let's make sure we don't update any embedding weights besides the newly added token
                self.adapter.restore_embeddings()

        loss_dict = OrderedDict(
            {'loss': loss.item()}
        )

        self.end_of_training_loop()

        return loss_dict