From 7909b50d2462968f04a9b223c5c8e71f56365e9f Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 14 Oct 2023 08:44:53 -0600 Subject: [PATCH] Added adapter assistance to SD training --- extensions_built_in/sd_trainer/SDTrainer.py | 173 +++++++++++--------- toolkit/config_modules.py | 3 + 2 files changed, 97 insertions(+), 79 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 33338292..bc78b369 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1,10 +1,12 @@ import os.path from collections import OrderedDict +from typing import Union from PIL import Image from diffusers import T2IAdapter from torch.utils.data import DataLoader +from toolkit.basic import value_map from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO from toolkit.ip_adapter import IPAdapter from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds @@ -30,10 +32,26 @@ class SDTrainer(BaseSDTrainProcess): def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): super().__init__(process_id, job, config, **kwargs) + self.assistant_adapter: Union['T2IAdapter', None] def before_model_load(self): pass + def before_dataset_load(self): + self.assistant_adapter = None + # get adapter assistant if one is set + if self.train_config.adapter_assist_name_or_path is not None: + adapter_path = self.train_config.adapter_assist_name_or_path + + # dont name this adapter since we are not training it + self.assistant_adapter = T2IAdapter.from_pretrained( + adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), variant="fp16" + ).to(self.device_torch) + self.assistant_adapter.eval() + self.assistant_adapter.requires_grad_(False) + flush() + + def hook_before_train_loop(self): # move vae to device if we did not cache latents if not self.is_latents_cached: @@ -44,53 +62,6 @@ class SDTrainer(BaseSDTrainProcess): 
self.sd.vae.to('cpu') flush() - def get_adapter_images(self, batch: 'DataLoaderBatchDTO'): - if self.adapter_config.image_dir is None: - # adapter needs 0 to 1 values, batch is -1 to 1 - adapter_batch = batch.tensor.clone().to( - self.device_torch, dtype=get_torch_dtype(self.train_config.dtype) - ) - adapter_batch = (adapter_batch + 1) / 2 - return adapter_batch - img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] - adapter_folder_path = self.adapter_config.image_dir - adapter_images = [] - # loop through images - for file_item in batch.file_items: - img_path = file_item.path - file_name_no_ext = os.path.basename(img_path).split('.')[0] - # find the image - for ext in img_ext_list: - if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)): - adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext)) - break - width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height - adapter_tensors = [] - # load images with torch transforms - for idx, adapter_image in enumerate(adapter_images): - # we need to centrally crop the largest dimension of the image to match the batch shape after scaling - # to the smallest dimension - img: Image.Image = Image.open(adapter_image) - if img.width > img.height: - # scale down so height is the same as batch - new_height = height - new_width = int(img.width * (height / img.height)) - else: - new_width = width - new_height = int(img.height * (width / img.width)) - - img = img.resize((new_width, new_height)) - crop_fn = transforms.CenterCrop((height, width)) - # crop the center to match batch - img = crop_fn(img) - img = adapter_transforms(img) - adapter_tensors.append(img) - - # stack them - adapter_tensors = torch.stack(adapter_tensors).to( - self.device_torch, dtype=get_torch_dtype(self.train_config.dtype) - ) - return adapter_tensors def hook_train_loop(self, batch): @@ -98,18 +69,21 @@ class SDTrainer(BaseSDTrainProcess): dtype = get_torch_dtype(self.train_config.dtype) 
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) network_weight_list = batch.get_network_weight_list() + + has_adapter_img = batch.control_tensor is not None + self.timer.stop('preprocess_batch') with torch.no_grad(): adapter_images = None sigmas = None - if self.adapter: + if has_adapter_img and (self.adapter or self.assistant_adapter): with self.timer('get_adapter_images'): # todo move this to data loader if batch.control_tensor is not None: adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach() else: - adapter_images = self.get_adapter_images(batch) + raise NotImplementedError("Adapter images now must be loaded with dataloader") # not 100% sure what this does. But they do it here # https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170 # sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype) @@ -128,9 +102,36 @@ class SDTrainer(BaseSDTrainProcess): mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1) mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach() + def get_adapter_multiplier(): + if self.adapter and isinstance(self.adapter, T2IAdapter): + # training a t2i adapter, not using as assistant. + return 1.0 + elif self.train_config.match_adapter_assist: + # training a texture. 
We want it high + adapter_strength_min = 0.9 + adapter_strength_max = 1.0 + else: + # training with assistance, we want it low + # adapter_strength_min = 0.5 + # adapter_strength_max = 0.8 + adapter_strength_min = 0.9 + adapter_strength_max = 1.1 + + adapter_conditioning_scale = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) + + adapter_conditioning_scale = value_map( + adapter_conditioning_scale, + 0.0, + 1.0, + adapter_strength_min, + adapter_strength_max + ) + return adapter_conditioning_scale + # flush() with self.timer('grad_setup'): - self.optimizer.zero_grad() # text encoding grad_on_text_encoder = False @@ -148,6 +149,7 @@ class SDTrainer(BaseSDTrainProcess): # set the weights network.multiplier = network_weight_list + self.optimizer.zero_grad(set_to_none=True) # activate network if it exits with network: @@ -159,15 +161,44 @@ class SDTrainer(BaseSDTrainProcess): conditional_embeds = conditional_embeds.detach() # flush() pred_kwargs = {} - if self.adapter and isinstance(self.adapter, T2IAdapter): - with self.timer('encode_adapter'): - down_block_additional_residuals = self.adapter(adapter_images) - down_block_additional_residuals = [ - sample.to(dtype=dtype) for sample in down_block_additional_residuals - ] - pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals + if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter): + with torch.set_grad_enabled(self.adapter is not None): + adapter = self.adapter if self.adapter else self.assistant_adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + down_block_additional_residuals = adapter(adapter_images) + if self.assistant_adapter: + # not training. 
detach + down_block_additional_residuals = [ + sample.to(dtype=dtype).detach() * adapter_multiplier for sample in down_block_additional_residuals + ] + else: + down_block_additional_residuals = [ + sample.to(dtype=dtype) * adapter_multiplier for sample in down_block_additional_residuals + ] - if self.adapter and isinstance(self.adapter, IPAdapter): + pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals + + control_pred = None + if has_adapter_img and self.assistant_adapter and self.train_config.match_adapter_assist: + # do a prediction here so we can match its output with network multiplier set to 0.0 + with torch.no_grad(): + # dont use network on this + network.multiplier = 0.0 + control_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), + timestep=timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ) + control_pred = control_pred.detach() + # remove the residuals as we wont use them on prediction when matching control + del pred_kwargs['down_block_additional_residuals'] + # restore network + network.multiplier = network_weight_list + + if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter): with self.timer('encode_adapter'): with torch.no_grad(): conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images) @@ -183,29 +214,13 @@ class SDTrainer(BaseSDTrainProcess): **pred_kwargs ) - # if self.adapter: - # # todo, diffusers does this on t2i training, is it better approach? 
- # # Denoise the latents - # denoised_latents = noise_pred * (-sigmas) + noisy_latents - # weighing = sigmas ** -2.0 - # - # # Get the target for loss depending on the prediction type - # if self.sd.noise_scheduler.config.prediction_type == "epsilon": - # target = batch.latents # we are computing loss against denoise latents - # elif self.sd.noise_scheduler.config.prediction_type == "v_prediction": - # target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps) - # else: - # raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}") - # - # # MSE loss - # loss = torch.mean( - # (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1), - # dim=1, - # ) - # else: with self.timer('calculate_loss'): noise = noise.to(self.device_torch, dtype=dtype).detach() - if self.sd.prediction_type == 'v_prediction': + + if control_pred is not None: + # matching adapter prediction + target = control_pred + elif self.sd.prediction_type == 'v_prediction': # v-parameterization training target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) else: diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 8a616805..be0e6725 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -121,6 +121,9 @@ class TrainConfig: self.max_grad_norm = kwargs.get('max_grad_norm', 1.0) self.start_step = kwargs.get('start_step', None) self.free_u = kwargs.get('free_u', False) + self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None) + self.match_adapter_assist = kwargs.get('match_adapter_assist', False) + class ModelConfig: