Created a v2 trainer and moved all the training logic into a single torch model so it can be run in parallel

This commit is contained in:
Jaret Burkett
2024-08-29 12:34:18 -06:00
parent 60232def91
commit a48c9aba8d
3 changed files with 1674 additions and 1 deletions
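The heart of the change is that the whole forward/loss computation now lives inside one torch.nn.Module (UnifiedTrainingModel), so the trainer only has to call a single module per step, and that module is what a data-parallel wrapper can replicate across GPUs. Below is a minimal sketch of that pattern; ToyTrainingModel, its backbone, and the batch layout are made up for illustration and are not the toolkit's actual classes (the real UnifiedTrainingModel takes the SD model, network, adapters, and configs, as the diff shows).

import torch
import torch.nn as nn

class ToyTrainingModel(nn.Module):
    # hypothetical stand-in for UnifiedTrainingModel: forward(batch) -> scalar loss
    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, batch):
        # everything needed to produce the loss happens inside forward(),
        # which is what lets a wrapper split the batch across devices
        pred = self.backbone(batch["inputs"])
        return nn.functional.mse_loss(pred, batch["targets"])

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ToyTrainingModel(nn.Linear(16, 16)).to(device)
if torch.cuda.device_count() > 1:
    # one line turns the single-GPU step into a multi-GPU step
    model = nn.DataParallel(model)

batch = {"inputs": torch.randn(8, 16, device=device),
         "targets": torch.randn(8, 16, device=device)}
loss = model(batch).mean()  # .mean() covers the case where each replica returns its own loss
loss.backward()

In the TrainerV2 below, hook_train_loop does exactly this from the trainer's side: it asks self.unified_training_model(batch) for the loss and never needs to know how that computation is distributed.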


@@ -0,0 +1,238 @@
import os
import random
from collections import OrderedDict
from typing import Union, List
import numpy as np
from diffusers import T2IAdapter, ControlNetModel
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.data_loader import get_dataloader_datasets
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.stable_diffusion_model import BlankNetwork
from toolkit.train_tools import get_torch_dtype, add_all_snr_to_noise_scheduler
import gc
import torch
from jobs.process import BaseSDTrainProcess
from torchvision import transforms
from diffusers import EMAModel
import math
from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.unified_training_model import UnifiedTrainingModel
def flush():
torch.cuda.empty_cache()
gc.collect()
adapter_transforms = transforms.Compose([
transforms.ToTensor(),
])
class TrainerV2(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super().__init__(process_id, job, config, **kwargs)
self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None]
self.do_prior_prediction = False
self.do_long_prompts = False
self.do_guided_loss = False
self._clip_image_embeds_unconditional: Union[List[str], None] = None
self.negative_prompt_pool: Union[List[str], None] = None
self.batch_negative_prompt: Union[List[str], None] = None
self.scaler = torch.cuda.amp.GradScaler()
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
self.do_grad_scale = True
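        # gradient scaling is skipped for full fine-tunes and for runs that train the adapter itself;
        # those call loss.backward() and optimizer.step() directly, everything else goes through the GradScaler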
if self.is_fine_tuning:
self.do_grad_scale = False
if self.adapter_config is not None:
if self.adapter_config.train:
self.do_grad_scale = False
if self.train_config.dtype in ["fp16", "float16"]:
# patch the scaler to allow fp16 training
org_unscale_grads = self.scaler._unscale_grads_
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
self.scaler._unscale_grads_ = _unscale_grads_replacer
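            # forcing allow_fp16=True lets scaler.unscale_ (and the grad clipping below) run even when the
            # parameters themselves are fp16, which GradScaler refuses by default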
        self.unified_training_model: Union[UnifiedTrainingModel, None] = None
def before_model_load(self):
pass
def before_dataset_load(self):
self.assistant_adapter = None
# get adapter assistant if one is set
if self.train_config.adapter_assist_name_or_path is not None:
adapter_path = self.train_config.adapter_assist_name_or_path
if self.train_config.adapter_assist_type == "t2i":
                # don't name this adapter since we are not training it
self.assistant_adapter = T2IAdapter.from_pretrained(
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
).to(self.device_torch)
elif self.train_config.adapter_assist_type == "control_net":
self.assistant_adapter = ControlNetModel.from_pretrained(
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
).to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
else:
raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}")
self.assistant_adapter.eval()
self.assistant_adapter.requires_grad_(False)
flush()
if self.train_config.train_turbo and self.train_config.show_turbo_outputs:
raise ValueError("Turbo outputs are not supported on MultiGPUSDTrainer")
def hook_before_train_loop(self):
# if self.train_config.do_prior_divergence:
# self.do_prior_prediction = True
# move vae to device if we did not cache latents
if not self.is_latents_cached:
self.sd.vae.eval()
self.sd.vae.to(self.device_torch)
else:
# offload it. Already cached
self.sd.vae.to('cpu')
flush()
add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
if self.adapter is not None:
self.adapter.to(self.device_torch)
# check if we have regs and using adapter and caching clip embeddings
has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0
is_caching_clip_embeddings = self.datasets is not None and any([self.datasets[i].cache_clip_vision_to_disk for i in range(len(self.datasets))])
if has_reg and is_caching_clip_embeddings:
# we need a list of unconditional clip image embeds from other datasets to handle regs
unconditional_clip_image_embeds = []
datasets = get_dataloader_datasets(self.data_loader)
for i in range(len(datasets)):
unconditional_clip_image_embeds += datasets[i].clip_vision_unconditional_cache
if len(unconditional_clip_image_embeds) == 0:
raise ValueError("No unconditional clip image embeds found. This should not happen")
self._clip_image_embeds_unconditional = unconditional_clip_image_embeds
if self.train_config.negative_prompt is not None:
raise ValueError("Negative prompt is not supported on MultiGPUSDTrainer")
# setup the unified training model
self.unified_training_model = UnifiedTrainingModel(
sd=self.sd,
network=self.network,
adapter=self.adapter,
assistant_adapter=self.assistant_adapter,
train_config=self.train_config,
adapter_config=self.adapter_config,
embedding=self.embedding,
timer=self.timer,
trigger_word=self.trigger_word,
)
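        # the unified model now owns the entire forward/loss computation for a step, so it is the single
        # module that can be replicated across devices (this is the point of the v2 trainer)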
# call parent hook
super().hook_before_train_loop()
# you can expand these in a child class to make customization easier
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
return self.unified_training_model.preprocess_batch(batch)
def before_unet_predict(self):
pass
def after_unet_predict(self):
pass
def end_of_training_loop(self):
pass
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
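        # One full optimizer step: compute the loss through the unified model, run backward
        # (through the GradScaler when enabled), clip gradients, step the optimizer, update the
        # EMA and LR scheduler, then restore any protected embedding/adapter weights.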
self.optimizer.zero_grad(set_to_none=True)
loss = self.unified_training_model(batch)
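        # guard against a NaN loss: swap in a zero tensor that still requires grad so this step is
        # effectively a no-op for the weights instead of propagating NaNs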
if torch.isnan(loss):
print("loss is nan")
loss = torch.zeros_like(loss).requires_grad_(True)
if self.network is not None:
network = self.network
else:
network = BlankNetwork()
        with network:
with self.timer('backward'):
                # todo: the multiplier is handled separately. Works for now since res are not in the same batch, but this needs to change
                # IMPORTANT: if gradient checkpointing is enabled, do not leave the `with network` block before calling backward.
                # Doing so destroys the gradients: the network is a context manager that resets the multipliers
                # back to 0.0 on exit, so they would be 0.0 during the backward pass and every gradient would be 0.0.
                # I spent weeks fighting this. DON'T DO IT
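                # illustrative sketch of the trap (not part of this file):
                #   with network:                 # multipliers active
                #       loss = model(batch)
                #       loss.backward()           # gradients flow through the network weights
                #   # exiting resets the multipliers to 0.0 -- a backward() out here sees a dead network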
# with fsdp_overlap_step_with_backward():
# if self.is_bfloat:
# loss.backward()
# else:
if not self.do_grad_scale:
loss.backward()
else:
self.scaler.scale(loss).backward()
if not self.is_grad_accumulation_step:
# fix this for multi params
if self.train_config.optimizer != 'adafactor':
if self.do_grad_scale:
self.scaler.unscale_(self.optimizer)
if isinstance(self.params[0], dict):
for i in range(len(self.params)):
torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
# only step if we are not accumulating
with self.timer('optimizer_step'):
# self.optimizer.step()
if not self.do_grad_scale:
self.optimizer.step()
else:
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad(set_to_none=True)
if self.ema is not None:
with self.timer('ema_update'):
self.ema.update()
else:
# gradient accumulation. Just a place for breakpoint
pass
# TODO Should we only step scheduler on grad step? If so, need to recalculate last step
with self.timer('scheduler_step'):
self.lr_scheduler.step()
if self.embedding is not None:
with self.timer('restore_embeddings'):
# Let's make sure we don't update any embedding weights besides the newly added token
self.embedding.restore_embeddings()
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
with self.timer('restore_adapter'):
# Let's make sure we don't update any embedding weights besides the newly added token
self.adapter.restore_embeddings()
loss_dict = OrderedDict(
{'loss': loss.item()}
)
self.end_of_training_loop()
return loss_dict


@@ -19,6 +19,23 @@ class SDTrainerExtension(Extension):
return SDTrainer
# This is for generic training (LoRA, Dreambooth, FineTuning)
class MultiGPUSDTrainerExtension(Extension):
    # uid must be unique; it is how the extension is identified
uid = "trainer_v2"
# name is the name of the extension for printing
name = "Trainer V2"
# This is where your process class is loaded
# keep your imports in here so they don't slow down the rest of the program
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .TrainerV2 import TrainerV2
return TrainerV2
# for backwards compatibility
class TextualInversionTrainer(SDTrainerExtension):
uid = "textual_inversion_trainer"
@@ -26,5 +43,5 @@ class TextualInversionTrainer(SDTrainerExtension):
AI_TOOLKIT_EXTENSIONS = [
# you can put a list of extensions here
SDTrainerExtension, TextualInversionTrainer
SDTrainerExtension, TextualInversionTrainer, MultiGPUSDTrainerExtension
]
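For context, the registration pattern above is just a list of extension classes that each expose a uid and a get_process() classmethod, so a loader can pick the right trainer by uid. A rough sketch of how such a lookup could work; the resolve_process helper below is hypothetical and not the toolkit's actual loader.

def resolve_process(uid: str):
    # hypothetical helper: find the extension whose uid matches and load its process class lazily
    for ext in AI_TOOLKIT_EXTENSIONS:
        if ext.uid == uid:
            return ext.get_process()
    raise ValueError(f"No extension registered with uid '{uid}'")

TrainerV2Class = resolve_process("trainer_v2")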

File diff suppressed because it is too large