From 3f3636b788eaec342d18c0b74fb50634a3eebe4c Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 8 Jun 2024 06:24:20 -0600 Subject: [PATCH] Bug fixes and little improvements here and there. --- extensions_built_in/sd_trainer/SDTrainer.py | 85 +++++--- jobs/process/BaseSDTrainProcess.py | 45 ++++- jobs/process/TrainVAEProcess.py | 103 +++++++--- requirements.txt | 3 +- testing/test_bucket_dataloader.py | 185 +++++++++++++----- toolkit/config_modules.py | 6 +- toolkit/dataloader_mixins.py | 3 +- toolkit/guidance.py | 4 +- toolkit/optimizer.py | 4 +- toolkit/stable_diffusion_model.py | 32 ++- toolkit/style.py | 2 +- toolkit/util/adafactor_stochastic_rounding.py | 3 +- 12 files changed, 358 insertions(+), 117 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index d08c9e9..c8f55f1 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -4,7 +4,7 @@ from collections import OrderedDict from typing import Union, Literal, List, Optional import numpy as np -from diffusers import T2IAdapter, AutoencoderTiny +from diffusers import T2IAdapter, AutoencoderTiny, ControlNetModel import torch.functional as F from safetensors.torch import load_file @@ -824,6 +824,10 @@ class SDTrainer(BaseSDTrainProcess): # remove the residuals as we wont use them on prediction when matching control if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs: del pred_kwargs['down_intrablock_additional_residuals'] + if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs: + del pred_kwargs['down_block_additional_residuals'] + if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs: + del pred_kwargs['mid_block_additional_residual'] if can_disable_adapter: self.adapter.is_active = was_adapter_active @@ -1065,7 +1069,7 @@ class SDTrainer(BaseSDTrainProcess): # if prompt_2 is not None: # prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))] - with network: + with (network): # encode clip adapter here so embeds are active for tokenizer if self.adapter and isinstance(self.adapter, ClipVisionAdapter): with self.timer('encode_clip_vision_embeds'): @@ -1162,26 +1166,27 @@ class SDTrainer(BaseSDTrainProcess): # flush() pred_kwargs = {} - if has_adapter_img and ( - (self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter): - with torch.set_grad_enabled(self.adapter is not None): - adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter - adapter_multiplier = get_adapter_multiplier() - with self.timer('encode_adapter'): - down_block_additional_residuals = adapter(adapter_images) - if self.assistant_adapter: - # not training. 
detach - down_block_additional_residuals = [ - sample.to(dtype=dtype).detach() * adapter_multiplier for sample in - down_block_additional_residuals - ] - else: - down_block_additional_residuals = [ - sample.to(dtype=dtype) * adapter_multiplier for sample in - down_block_additional_residuals - ] - pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals + if has_adapter_img: + if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)): + with torch.set_grad_enabled(self.adapter is not None): + adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + down_block_additional_residuals = adapter(adapter_images) + if self.assistant_adapter: + # not training. detach + down_block_additional_residuals = [ + sample.to(dtype=dtype).detach() * adapter_multiplier for sample in + down_block_additional_residuals + ] + else: + down_block_additional_residuals = [ + sample.to(dtype=dtype) * adapter_multiplier for sample in + down_block_additional_residuals + ] + + pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals if self.adapter and isinstance(self.adapter, IPAdapter): with self.timer('encode_adapter_embeds'): @@ -1362,6 +1367,32 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.do_cfg: self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), is_unconditional=True) + if has_adapter_img: + if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)): + if self.train_config.do_cfg: + raise ValueError("ControlNetModel is not supported with CFG") + with torch.set_grad_enabled(self.adapter is not None): + adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + # add_text_embeds is pooled_prompt_embeds for sdxl + added_cond_kwargs = {} + if self.sd.is_xl: + added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds + added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents) + down_block_res_samples, mid_block_res_sample = adapter( + noisy_latents, + timesteps, + encoder_hidden_states=conditional_embeds.text_embeds, + controlnet_cond=adapter_images, + conditioning_scale=1.0, + guess_mode=False, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + pred_kwargs['down_block_additional_residuals'] = down_block_res_samples + pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample + self.before_unet_predict() # do a prior pred if we have an unconditional image, we will swap out the giadance later @@ -1423,10 +1454,10 @@ class SDTrainer(BaseSDTrainProcess): # 0.0 for the backward pass and the gradients will be 0.0 # I spent weeks on fighting this. 
DON'T DO IT # with fsdp_overlap_step_with_backward(): - if self.is_bfloat: - loss.backward() - else: - self.scaler.scale(loss).backward() + # if self.is_bfloat: + loss.backward() + # else: + # self.scaler.scale(loss).backward() # flush() if not self.is_grad_accumulation_step: @@ -1443,8 +1474,8 @@ class SDTrainer(BaseSDTrainProcess): self.optimizer.step() else: # apply gradients - self.scaler.step(self.optimizer) - self.scaler.update() + self.optimizer.step() + # self.scaler.update() # self.optimizer.step() self.optimizer.zero_grad(set_to_none=True) else: diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 9adf05d..8567477 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -10,7 +10,7 @@ from typing import Union, List, Optional import numpy as np import yaml -from diffusers import T2IAdapter +from diffusers import T2IAdapter, ControlNetModel from safetensors.torch import save_file, load_file # from lycoris.config import PRESET from torch.utils.data import DataLoader @@ -143,7 +143,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # to hold network if there is one self.network: Union[Network, None] = None - self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, None] = None + self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, ControlNetModel, None] = None self.embedding: Union[Embedding, None] = None is_training_adapter = self.adapter_config is not None and self.adapter_config.train @@ -368,6 +368,7 @@ class BaseSDTrainProcess(BaseTrainProcess): pass def save(self, step=None): + flush() if not os.path.exists(self.save_root): os.makedirs(self.save_root, exist_ok=True) @@ -423,6 +424,8 @@ class BaseSDTrainProcess(BaseTrainProcess): # add _lora to name if self.adapter_config.type == 't2i': adapter_name += '_t2i' + elif self.adapter_config.type == 'control_net': + adapter_name += '_cn' elif self.adapter_config.type == 'clip': adapter_name += '_clip' elif self.adapter_config.type.startswith('ip'): @@ -441,6 +444,23 @@ class BaseSDTrainProcess(BaseTrainProcess): meta=save_meta, dtype=get_torch_dtype(self.save_config.dtype) ) + elif self.adapter_config.type == 'control_net': + # save in diffusers format + name_or_path = file_path.replace('.safetensors', '') + # move it to the new dtype and cpu + orig_device = self.adapter.device + orig_dtype = self.adapter.dtype + self.adapter = self.adapter.to(torch.device('cpu'), dtype=get_torch_dtype(self.save_config.dtype)) + self.adapter.save_pretrained( + name_or_path, + dtype=get_torch_dtype(self.save_config.dtype), + safe_serialization=True + ) + meta_path = os.path.join(name_or_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(self.meta, f) + # move it back + self.adapter = self.adapter.to(orig_device, dtype=orig_dtype) else: save_ip_adapter_from_diffusers( state_dict, @@ -551,6 +571,8 @@ class BaseSDTrainProcess(BaseTrainProcess): paths = [p for p in paths if '_refiner' not in p] if '_t2i' not in name: paths = [p for p in paths if '_t2i' not in p] + if '_cn' not in name: + paths = [p for p in paths if '_cn' not in p] if len(paths) > 0: latest_path = max(paths, key=os.path.getctime) @@ -956,8 +978,11 @@ class BaseSDTrainProcess(BaseTrainProcess): def setup_adapter(self): # t2i adapter is_t2i = self.adapter_config.type == 't2i' + is_control_net = self.adapter_config.type == 'control_net' if self.adapter_config.type == 't2i': suffix = 't2i' + elif self.adapter_config.type == 
'control_net': + suffix = 'cn' elif self.adapter_config.type == 'clip': suffix = 'clip' elif self.adapter_config.type == 'reference': @@ -990,6 +1015,16 @@ class BaseSDTrainProcess(BaseTrainProcess): downscale_factor=self.adapter_config.downscale_factor, adapter_type=self.adapter_config.adapter_type, ) + elif is_control_net: + if self.adapter_config.name_or_path is None: + raise ValueError("ControlNet requires a name_or_path to load from currently") + load_from_path = self.adapter_config.name_or_path + if latest_save_path is not None: + load_from_path = latest_save_path + self.adapter = ControlNetModel.from_pretrained( + load_from_path, + torch_dtype=get_torch_dtype(self.train_config.dtype), + ) elif self.adapter_config.type == 'clip': self.adapter = ClipVisionAdapter( sd=self.sd, @@ -1013,7 +1048,7 @@ class BaseSDTrainProcess(BaseTrainProcess): adapter_config=self.adapter_config, ) self.adapter.to(self.device_torch, dtype=dtype) - if latest_save_path is not None: + if latest_save_path is not None and not is_control_net: # load adapter from path print(f"Loading adapter from {latest_save_path}") if is_t2i: @@ -1040,8 +1075,8 @@ class BaseSDTrainProcess(BaseTrainProcess): dtype=dtype ) self.adapter.load_state_dict(loaded_state_dict) - if self.adapter_config.train: - self.load_training_state_from_metadata(latest_save_path) + if latest_save_path is not None and self.adapter_config.train: + self.load_training_state_from_metadata(latest_save_path) # set trainable params self.sd.adapter = self.adapter diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py index 00b098c..fb6536c 100644 --- a/jobs/process/TrainVAEProcess.py +++ b/jobs/process/TrainVAEProcess.py @@ -1,6 +1,7 @@ import copy import glob import os +import shutil import time from collections import OrderedDict @@ -13,6 +14,7 @@ from torch import nn from torchvision.transforms import transforms from jobs.process import BaseTrainProcess +from toolkit.image_utils import show_tensors from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm from toolkit.data_loader import ImageDataset from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss @@ -25,6 +27,8 @@ from tqdm import tqdm import time import numpy as np from .models.vgg19_critic import Critic +from torchvision.transforms import Resize +import lpips IMAGE_TRANSFORMS = transforms.Compose( [ @@ -62,6 +66,7 @@ class TrainVAEProcess(BaseTrainProcess): self.kld_weight = self.get_conf('kld_weight', 0, as_type=float) self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float) self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float) + self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float) self.critic_weight = self.get_conf('critic_weight', 1, as_type=float) self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float) self.optimizer_params = self.get_conf('optimizer_params', {}) @@ -71,6 +76,9 @@ class TrainVAEProcess(BaseTrainProcess): self.vgg_19 = None self.style_weight_scalers = [] self.content_weight_scalers = [] + self.lpips_loss:lpips.LPIPS = None + + self.vae_scale_factor = 8 self.step_num = 0 self.epoch_num = 0 @@ -137,6 +145,15 @@ class TrainVAEProcess(BaseTrainProcess): num_workers=6 ) + def remove_oldest_checkpoint(self): + max_to_keep = 4 + folders = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers")) + if len(folders) > max_to_keep: + folders.sort(key=os.path.getmtime) + for folder in folders[:-max_to_keep]: + print(f"Removing {folder}") + 
shutil.rmtree(folder) + def setup_vgg19(self): if self.vgg_19 is None: self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses( @@ -211,7 +228,7 @@ class TrainVAEProcess(BaseTrainProcess): def get_pattern_loss(self, pred, target): if self._pattern_loss is None: - self._pattern_loss = PatternLoss(pattern_size=8, dtype=self.torch_dtype).to(self.device, + self._pattern_loss = PatternLoss(pattern_size=16, dtype=self.torch_dtype).to(self.device, dtype=self.torch_dtype) loss = torch.mean(self._pattern_loss(pred, target)) return loss @@ -226,25 +243,21 @@ class TrainVAEProcess(BaseTrainProcess): step_num = f"_{str(step).zfill(9)}" self.update_training_metadata() - filename = f'{self.job.name}{step_num}.safetensors' - # prepare meta - save_meta = get_meta_for_safetensors(self.meta, self.job.name) + filename = f'{self.job.name}{step_num}_diffusers' - state_dict = convert_diffusers_back_to_ldm(self.vae) - - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(torch.float32) - state_dict[key] = v - - # having issues with meta - save_file(state_dict, os.path.join(self.save_root, filename), save_meta) + self.vae = self.vae.to("cpu", dtype=torch.float16) + self.vae.save_pretrained( + save_directory=os.path.join(self.save_root, filename) + ) + self.vae = self.vae.to(self.device, dtype=self.torch_dtype) self.print(f"Saved to {os.path.join(self.save_root, filename)}") if self.use_critic: self.critic.save(step) + self.remove_oldest_checkpoint() + def sample(self, step=None): sample_folder = os.path.join(self.save_root, 'samples') if not os.path.exists(sample_folder): @@ -280,6 +293,13 @@ class TrainVAEProcess(BaseTrainProcess): output_img.paste(input_img, (0, 0)) output_img.paste(decoded, (self.resolution, 0)) + scale_up = 2 + if output_img.height <= 300: + scale_up = 4 + + # scale up using nearest neighbor + output_img = output_img.resize((output_img.width * scale_up, output_img.height * scale_up), Image.NEAREST) + step_num = '' if step is not None: # zero-pad 9 digits @@ -294,7 +314,7 @@ class TrainVAEProcess(BaseTrainProcess): path_to_load = self.vae_path # see if we have a checkpoint in out output to resume from self.print(f"Looking for latest checkpoint in {self.save_root}") - files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*.safetensors")) + files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers")) if files and len(files) > 0: latest_file = max(files, key=os.path.getmtime) print(f" - Latest checkpoint is: {latest_file}") @@ -306,13 +326,14 @@ class TrainVAEProcess(BaseTrainProcess): self.print(f"Loading VAE") self.print(f" - Loading VAE: {path_to_load}") if self.vae is None: - self.vae = load_vae(path_to_load, dtype=self.torch_dtype) + self.vae = AutoencoderKL.from_pretrained(path_to_load) # set decoder to train self.vae.to(self.device, dtype=self.torch_dtype) self.vae.requires_grad_(False) self.vae.eval() self.vae.decoder.train() + self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1) def run(self): super().run() @@ -374,6 +395,10 @@ class TrainVAEProcess(BaseTrainProcess): if self.use_critic: self.critic.setup() + if self.lpips_weight > 0 and self.lpips_loss is None: + # self.lpips_loss = lpips.LPIPS(net='vgg') + self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=self.torch_dtype) + optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate, optimizer_params=self.optimizer_params) @@ -397,6 +422,7 @@ class 
TrainVAEProcess(BaseTrainProcess): self.sample() blank_losses = OrderedDict({ "total": [], + "lpips": [], "style": [], "content": [], "mse": [], @@ -415,17 +441,29 @@ class TrainVAEProcess(BaseTrainProcess): for batch in self.data_loader: if self.step_num >= self.max_steps: break + with torch.no_grad(): - batch = batch.to(self.device, dtype=self.torch_dtype) + batch = batch.to(self.device, dtype=self.torch_dtype) - # forward pass - dgd = self.vae.encode(batch).latent_dist - mu, logvar = dgd.mean, dgd.logvar - latents = dgd.sample() - latents.requires_grad_(True) + # resize so it matches size of vae evenly + if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0: + batch = Resize((batch.shape[2] // self.vae_scale_factor * self.vae_scale_factor, + batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch) + + # forward pass + dgd = self.vae.encode(batch).latent_dist + mu, logvar = dgd.mean, dgd.logvar + latents = dgd.sample() + latents.detach().requires_grad_(True) pred = self.vae.decode(latents).sample + with torch.no_grad(): + show_tensors( + pred.clamp(-1, 1).clone(), + "combined tensor" + ) + # Run through VGG19 if self.style_weight > 0 or self.content_weight > 0 or self.use_critic: stacked = torch.cat([pred, batch], dim=0) @@ -441,14 +479,31 @@ class TrainVAEProcess(BaseTrainProcess): content_loss = self.get_content_loss() * self.content_weight kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight + if self.lpips_weight > 0: + lpips_loss = self.lpips_loss( + pred.clamp(-1, 1), + batch.clamp(-1, 1) + ).mean() * self.lpips_weight + else: + lpips_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight if self.use_critic: critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight + + # do not let abs critic gen loss be higher than abs lpips * 0.1 if using it + if self.lpips_weight > 0: + max_target = lpips_loss.abs() * 0.1 + with torch.no_grad(): + crit_g_scaler = 1.0 + if critic_gen_loss.abs() > max_target: + crit_g_scaler = max_target / critic_gen_loss.abs() + + critic_gen_loss *= crit_g_scaler else: critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) - loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss # Backward pass and optimization optimizer.zero_grad() @@ -460,6 +515,8 @@ class TrainVAEProcess(BaseTrainProcess): loss_value = loss.item() # get exponent like 3.54e-4 loss_string = f"loss: {loss_value:.2e}" + if self.lpips_weight > 0: + loss_string += f" lpips: {lpips_loss.item():.2e}" if self.content_weight > 0: loss_string += f" cnt: {content_loss.item():.2e}" if self.style_weight > 0: @@ -496,6 +553,7 @@ class TrainVAEProcess(BaseTrainProcess): self.progress_bar.update(1) epoch_losses["total"].append(loss_value) + epoch_losses["lpips"].append(lpips_loss.item()) epoch_losses["style"].append(style_loss.item()) epoch_losses["content"].append(content_loss.item()) epoch_losses["mse"].append(mse_loss.item()) @@ -506,6 +564,7 @@ class TrainVAEProcess(BaseTrainProcess): epoch_losses["crD"].append(critic_d_loss) log_losses["total"].append(loss_value) + log_losses["lpips"].append(lpips_loss.item()) 
log_losses["style"].append(style_loss.item()) log_losses["content"].append(content_loss.item()) log_losses["mse"].append(mse_loss.item()) diff --git a/requirements.txt b/requirements.txt index 6caa30b..19cdbcf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,4 +24,5 @@ controlnet_aux==0.0.7 python-dotenv bitsandbytes xformers -hf_transfer \ No newline at end of file +hf_transfer +lpips \ No newline at end of file diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py index f830517..9d56b01 100644 --- a/testing/test_bucket_dataloader.py +++ b/testing/test_bucket_dataloader.py @@ -7,10 +7,11 @@ from torchvision import transforms import sys import os import cv2 +import random sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from toolkit.paths import SD_SCRIPTS_ROOT - +import torchvision.transforms.functional from toolkit.image_utils import show_img sys.path.append(SD_SCRIPTS_ROOT) @@ -32,7 +33,7 @@ parser.add_argument('--epochs', type=int, default=1) args = parser.parse_args() dataset_folder = args.dataset_folder -resolution = 1024 +resolution = 512 bucket_tolerance = 64 batch_size = 1 @@ -41,6 +42,7 @@ batch_size = 1 dataset_config = DatasetConfig( dataset_path=dataset_folder, + control_path=dataset_folder, resolution=resolution, # caption_ext='json', default_caption='default', @@ -48,62 +50,135 @@ dataset_config = DatasetConfig( buckets=True, bucket_tolerance=bucket_tolerance, # poi='person', - shuffle_augmentations=True, + # shuffle_augmentations=True, # augmentations=[ # { - # 'method': 'GaussianBlur', - # 'blur_limit': (1, 16), - # 'sigma_limit': (0, 8), - # 'p': 0.8 - # }, - # { - # 'method': 'ImageCompression', - # 'quality_lower': 10, - # 'quality_upper': 100, - # 'compression_type': 0, - # 'p': 0.8 - # }, - # { - # 'method': 'ImageCompression', - # 'quality_lower': 20, - # 'quality_upper': 100, - # 'compression_type': 1, - # 'p': 0.8 - # }, - # { - # 'method': 'RingingOvershoot', - # 'blur_limit': (3, 35), - # 'cutoff': (0.7, 1.96), - # 'p': 0.8 - # }, - # { - # 'method': 'GaussNoise', - # 'var_limit': (0, 300), - # 'per_channel': True, - # 'mean': 0.0, - # 'p': 0.8 - # }, - # { - # 'method': 'GlassBlur', - # 'sigma': 0.6, - # 'max_delta': 7, - # 'iterations': 2, - # 'mode': 'fast', - # 'p': 0.8 - # }, - # { - # 'method': 'Downscale', - # 'scale_max': 0.5, - # 'interpolation': 'cv2.INTER_CUBIC', - # 'p': 0.8 + # 'method': 'Posterize', + # 'num_bits': [(0, 4), (0, 4), (0, 4)], + # 'p': 1.0 # }, + # # ] - - ) dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size) +def random_blur(img, min_kernel_size=3, max_kernel_size=23, p=0.5): + if random.random() < p: + kernel_size = random.randint(min_kernel_size, max_kernel_size) + # make sure it is odd + if kernel_size % 2 == 0: + kernel_size += 1 + img = torchvision.transforms.functional.gaussian_blur(img, kernel_size=kernel_size) + return img + +def quantize(image, palette): + """ + Similar to PIL.Image.quantize() in PyTorch. Built to maintain gradient. + Only works for one image i.e. CHW. Does NOT work for batches. 
+ ref https://discuss.pytorch.org/t/color-quantization/104528/4 + """ + + orig_dtype = image.dtype + + C, H, W = image.shape + n_colors = palette.shape[0] + + # Easier to work with list of colors + flat_img = image.view(C, -1).T # [C, H, W] -> [H*W, C] + + # Repeat image so that there are n_color number of columns of the same image + flat_img_per_color = flat_img.unsqueeze(1).expand(-1, n_colors, -1) # [H*W, C] -> [H*W, n_colors, C] + + # Get euclidean distance between each pixel in each column and the column's respective color + # i.e. column 1 lists distance of each pixel to color #1 in palette, column 2 to color #2 etc. + squared_distance = (flat_img_per_color - palette.unsqueeze(0)) ** 2 + euclidean_distance = torch.sqrt(torch.sum(squared_distance, dim=-1) + 1e-8) # [H*W, n_colors, C] -> [H*W, n_colors] + + # Get the shortest distance (one value per row (H*W) is selected) + min_distances, min_indices = torch.min(euclidean_distance, dim=-1) # [H*W, n_colors] -> [H*W] + + # Create a mask for the closest colors + one_hot_mask = torch.nn.functional.one_hot(min_indices, num_classes=n_colors).float() # [H*W, n_colors] + + # Multiply the mask with the palette colors to get the quantized image + quantized = torch.matmul(one_hot_mask, palette) # [H*W, n_colors] @ [n_colors, C] -> [H*W, C] + + # Reshape it back to the original input format. + quantized_img = quantized.T.view(C, H, W) # [H*W, C] -> [C, H, W] + + return quantized_img.to(orig_dtype) + + + +def color_block_imgs(img, neg1_1=False): + # expects values 0 - 1 + orig_dtype = img.dtype + if neg1_1: + img = img * 0.5 + 0.5 + + img = img * 255 + img = img.clamp(0, 255) + img = img.to(torch.uint8) + + img_chunks = torch.chunk(img, img.shape[0], dim=0) + + posterized_chunks = [] + + for chunk in img_chunks: + img_size = (chunk.shape[2] + chunk.shape[3]) // 2 + # min kernel size of 1% of image, max 10% + min_kernel_size = int(img_size * 0.01) + max_kernel_size = int(img_size * 0.1) + + # blur first + chunk = random_blur(chunk, min_kernel_size=min_kernel_size, max_kernel_size=max_kernel_size, p=0.8) + num_colors = random.randint(1, 16) + + resize_to = 16 + # chunk = torchvision.transforms.functional.posterize(chunk, num_bits_to_use) + + # mean_color = [int(x.item()) for x in torch.mean(chunk.float(), dim=(0, 2, 3))] + + # shrink the image down to num_colors x num_colors + shrunk = torchvision.transforms.functional.resize(chunk, [resize_to, resize_to]) + + mean_color = [int(x.item()) for x in torch.mean(shrunk.float(), dim=(0, 2, 3))] + + colors = shrunk.view(3, -1).T + # remove duplicates + colors = torch.unique(colors, dim=0) + colors = colors.numpy() + colors = colors.tolist() + + use_colors = [random.choice(colors) for _ in range(num_colors)] + + pallette = torch.tensor([ + [0, 0, 0], + mean_color, + [255, 255, 255], + ] + use_colors, dtype=torch.float32) + chunk = quantize(chunk.squeeze(0), pallette).unsqueeze(0) + + # chunk = torchvision.transforms.functional.equalize(chunk) + # color jitter + if random.random() < 0.5: + chunk = torchvision.transforms.functional.adjust_contrast(chunk, random.uniform(1.0, 1.5)) + if random.random() < 0.5: + chunk = torchvision.transforms.functional.adjust_saturation(chunk, random.uniform(1.0, 2.0)) + # if random.random() < 0.5: + # chunk = torchvision.transforms.functional.adjust_brightness(chunk, random.uniform(0.5, 1.5)) + chunk = random_blur(chunk, p=0.6) + posterized_chunks.append(chunk) + + img = torch.cat(posterized_chunks, dim=0) + img = img.to(orig_dtype) + img = img / 255 + + if neg1_1: + img = img 
* 2 - 1 + return img + # run through an epoch ang check sizes dataloader_iterator = iter(dataloader) @@ -112,11 +187,19 @@ for epoch in range(args.epochs): batch: 'DataLoaderBatchDTO' img_batch = batch.tensor + img_batch = color_block_imgs(img_batch, neg1_1=True) + chunks = torch.chunk(img_batch, batch_size, dim=0) # put them so they are size by side big_img = torch.cat(chunks, dim=3) big_img = big_img.squeeze(0) + control_chunks = torch.chunk(batch.control_tensor, batch_size, dim=0) + big_control_img = torch.cat(control_chunks, dim=3) + big_control_img = big_control_img.squeeze(0) * 2 - 1 + + big_img = torch.cat([big_img, big_control_img], dim=2) + min_val = big_img.min() max_val = big_img.max() @@ -127,7 +210,7 @@ for epoch in range(args.epochs): show_img(img) - # time.sleep(1.0) + time.sleep(1.0) # if not last epoch if epoch < args.epochs - 1: trigger_dataloader_setup_epoch(dataloader) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index a6b9a9f..4d29131 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -129,13 +129,13 @@ class NetworkConfig: self.conv = 4 -AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker'] +AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net'] CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state'] class AdapterConfig: def __init__(self, **kwargs): - self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip + self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip, control_net self.in_channels: int = kwargs.get('in_channels', 3) self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280]) self.num_res_blocks: int = kwargs.get('num_res_blocks', 2) @@ -530,6 +530,8 @@ class DatasetConfig: self.prefetch_factor: int = kwargs.get('prefetch_factor', 2) self.extra_values: List[float] = kwargs.get('extra_values', []) self.square_crop: bool = kwargs.get('square_crop', False) + # apply same augmentations to control images. 
Usually want this true unless special case + self.replay_transforms: bool = kwargs.get('replay_transforms', True) def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]: diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 6e7b13b..045107a 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -860,7 +860,8 @@ class AugmentationFileItemDTOMixin: # only store the spatial transforms augmented_params['transforms'] = [t for t in augmented_params['transforms'] if t['__class_fullname__'].split('.')[-1] in spatial_transforms] - self.aug_replay_spatial_transforms = augmented_params + if self.dataset_config.replay_transforms: + self.aug_replay_spatial_transforms = augmented_params # convert back to RGB tensor augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB) diff --git a/toolkit/guidance.py b/toolkit/guidance.py index 13ab65e..ba69c3e 100644 --- a/toolkit/guidance.py +++ b/toolkit/guidance.py @@ -240,7 +240,7 @@ def get_direct_guidance_loss( noise_pred_uncond, noise_pred_cond = torch.chunk(prediction, 2, dim=0) - guidance_scale = 1.0 + guidance_scale = 1.25 guidance_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) @@ -586,6 +586,8 @@ def get_guided_tnt( loss = prior_loss + this_loss - that_loss + loss = loss.mean() + loss.backward() # detach it so parent class can run backward on no grads without throwing error diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py index f61fea8..d2a1e92 100644 --- a/toolkit/optimizer.py +++ b/toolkit/optimizer.py @@ -1,5 +1,5 @@ import torch -from transformers import Adafactor +from transformers import Adafactor, AdamW def get_optimizer( @@ -69,7 +69,7 @@ def get_optimizer( if 'relative_step' not in optimizer_params: optimizer_params['relative_step'] = False if 'scale_parameter' not in optimizer_params: - optimizer_params['scale_parameter'] = True + optimizer_params['scale_parameter'] = False if 'warmup_init' not in optimizer_params: optimizer_params['warmup_init'] = False optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index aaf04cd..3990c13 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -39,7 +39,8 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \ - StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel + StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \ + StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline import diffusers from diffusers import \ AutoencoderKL, \ @@ -497,6 +498,12 @@ class StableDiffusion: else: Pipe = StableDiffusionAdapterPipeline extra_args['adapter'] = self.adapter + elif isinstance(self.adapter, ControlNetModel): + if self.is_xl: + Pipe = StableDiffusionXLControlNetPipeline + else: + Pipe = StableDiffusionControlNetPipeline + extra_args['controlnet'] = self.adapter elif isinstance(self.adapter, ReferenceAdapter): # pass the noise scheduler to the adapter self.adapter.noise_scheduler = noise_scheduler @@ -588,6 +595,10 @@ class StableDiffusion: validation_image = 
validation_image.resize((gen_config.width * 2, gen_config.height * 2)) extra['image'] = validation_image extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, ControlNetModel): + validation_image = validation_image.resize((gen_config.width, gen_config.height)) + extra['image'] = validation_image + extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): transform = transforms.Compose([ transforms.ToTensor(), @@ -967,6 +978,16 @@ class StableDiffusion: if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0) + # handle controlnet + if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs: + # go through each item and concat if doing cfg and it doesnt have the same shape + for idx, item in enumerate(kwargs['down_block_additional_residuals']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['down_block_additional_residuals'][idx] = torch.cat([item] * 2, dim=0) + for idx, item in enumerate(kwargs['mid_block_additional_residual']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['mid_block_additional_residual'][idx] = torch.cat([item] * 2, dim=0) + def scale_model_input(model_input, timestep_tensor): if is_input_scaled: return model_input @@ -1383,11 +1404,13 @@ class StableDiffusion: # move to device and dtype image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list] + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) + # resize images if not divisible by 8 for i in range(len(image_list)): image = image_list[i] - if image.shape[1] % 8 != 0 or image.shape[2] % 8 != 0: - image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image) + if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0: + image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image) images = torch.stack(image_list) if isinstance(self.vae, AutoencoderTiny): @@ -1756,6 +1779,9 @@ class StableDiffusion: elif isinstance(self.adapter, T2IAdapter): requires_grad = self.adapter.adapter.conv_in.weight.requires_grad adapter_device = self.adapter.device + elif isinstance(self.adapter, ControlNetModel): + requires_grad = self.adapter.conv_in.training + adapter_device = self.adapter.device elif isinstance(self.adapter, ClipVisionAdapter): requires_grad = self.adapter.embedder.training adapter_device = self.adapter.device diff --git a/toolkit/style.py b/toolkit/style.py index b08214a..26ac33f 100644 --- a/toolkit/style.py +++ b/toolkit/style.py @@ -158,7 +158,7 @@ def get_style_model_and_losses( ): # content_layers = ['conv_4'] # style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'] - content_layers = ['conv2_2', 'conv3_2', 'conv4_2', 'conv5_2'] + content_layers = ['conv2_2', 'conv3_2', 'conv4_2'] style_layers = ['conv2_1', 'conv3_1', 'conv4_1'] cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval() # set all weights in the model to our dtype diff --git a/toolkit/util/adafactor_stochastic_rounding.py b/toolkit/util/adafactor_stochastic_rounding.py index 0e65a32..c993032 100644 --- 
a/toolkit/util/adafactor_stochastic_rounding.py +++ b/toolkit/util/adafactor_stochastic_rounding.py @@ -81,7 +81,8 @@ def step_adafactor(self, closure=None): lr = self._get_lr(group, state) beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) - update = (grad ** 2) + group["eps"][0] + eps = group["eps"][0] if isinstance(group["eps"], list) else group["eps"] + update = (grad ** 2) + eps if factored: exp_avg_sq_row = state["exp_avg_sq_row"] exp_avg_sq_col = state["exp_avg_sq_col"]
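Note (illustrative, not part of the diff): the hunk above makes the eps lookup tolerate both the usual list-style eps pair and a bare float, instead of always indexing group["eps"][0]. The equivalent logic in isolation, with resolve_eps as a hypothetical helper name:

def resolve_eps(group_eps):
    # Adafactor's eps is normally a 2-element pair (eps1, eps2); accept a plain float too.
    if isinstance(group_eps, (list, tuple)):
        return group_eps[0]
    return group_eps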
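Note (illustrative, not part of the diff): the ControlNet additions in SDTrainer.py and stable_diffusion_model.py hand the ControlNet's residuals to the UNet through pred_kwargs. A minimal stand-alone sketch of that wiring, assuming the standard diffusers ControlNetModel / UNet2DConditionModel API; the model ids, dtypes, and tensor shapes below are placeholders and loading them downloads weights:

import torch
from diffusers import ControlNetModel, UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

noisy_latents = torch.randn(1, 4, 64, 64)
timesteps = torch.tensor([500])
encoder_hidden_states = torch.randn(1, 77, 768)   # SD1.5 text-encoder hidden states
control_image = torch.rand(1, 3, 512, 512)        # conditioning image in pixel space

# The ControlNet returns per-down-block residuals plus a mid-block residual...
down_res, mid_res = controlnet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=encoder_hidden_states,
    controlnet_cond=control_image,
    conditioning_scale=1.0,
    guess_mode=False,
    return_dict=False,
)

# ...which the UNet consumes through the same two kwargs SDTrainer stores in pred_kwargs.
noise_pred = unet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=encoder_hidden_states,
    down_block_additional_residuals=down_res,
    mid_block_additional_residual=mid_res,
).sample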
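Note (illustrative, not part of the diff): TrainVAEProcess.py now adds an LPIPS perceptual term and caps the critic's generator loss at 10% of that term's magnitude. A minimal sketch of those two pieces in isolation, assuming the lpips package added to requirements.txt; the tensors are dummy stand-ins for the reconstruction and the input batch:

import torch
import lpips

device = "cuda" if torch.cuda.is_available() else "cpu"
lpips_fn = lpips.LPIPS(net="vgg").to(device)

pred = torch.rand(4, 3, 256, 256, device=device) * 2 - 1    # VAE reconstruction
target = torch.rand(4, 3, 256, 256, device=device) * 2 - 1  # original batch

# LPIPS expects inputs in [-1, 1], hence the clamps in the training loop.
lpips_weight = 1.0
lpips_loss = lpips_fn(pred.clamp(-1, 1), target.clamp(-1, 1)).mean() * lpips_weight

# Keep the adversarial term from overpowering the perceptual term.
critic_gen_loss = torch.tensor(0.5, device=device)  # stand-in for the real critic loss
max_target = lpips_loss.abs() * 0.1
with torch.no_grad():
    crit_g_scaler = 1.0
    if critic_gen_loss.abs() > max_target:
        crit_g_scaler = max_target / critic_gen_loss.abs()
critic_gen_loss = critic_gen_loss * crit_g_scaler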
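Note (illustrative, not part of the diff): both TrainVAEProcess.py and StableDiffusion.encode_images now derive the VAE downscale factor from the model config instead of hard-coding 8, and snap image sizes to a multiple of it before encoding. A small helper sketching that logic, assuming vae is a loaded diffusers AutoencoderKL:

import torch
from torchvision.transforms import Resize

def snap_to_vae_grid(image: torch.Tensor, vae) -> torch.Tensor:
    # 2 ** (number of downsampling stages); 8 for the standard SD VAE.
    scale = 2 ** (len(vae.config["block_out_channels"]) - 1)
    h, w = image.shape[-2], image.shape[-1]
    if h % scale == 0 and w % scale == 0:
        return image
    return Resize((h // scale * scale, w // scale * scale))(image)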