From 80aa2dbb80e0178e80b70e334b77709ce0260815 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Wed, 24 Jul 2024 04:13:41 -0600
Subject: [PATCH] New image generation img2img. Various tweaks and fixes

---
 .../advanced_generator/Img2ImgGenerator.py   | 188 ++++++++++++++++++
 .../advanced_generator/__init__.py           |  19 +-
 extensions_built_in/sd_trainer/SDTrainer.py  |   1 +
 jobs/process/BaseSDTrainProcess.py           |  18 +-
 jobs/process/GenerateProcess.py              |  85 ++++----
 toolkit/config_modules.py                    |   2 +
 toolkit/dataloader_mixins.py                 |   2 +-
 toolkit/guidance.py                          |   8 +-
 toolkit/sampler.py                           |   9 +-
 9 files changed, 285 insertions(+), 47 deletions(-)
 create mode 100644 extensions_built_in/advanced_generator/Img2ImgGenerator.py

diff --git a/extensions_built_in/advanced_generator/Img2ImgGenerator.py b/extensions_built_in/advanced_generator/Img2ImgGenerator.py
new file mode 100644
index 00000000..2a5cfe3f
--- /dev/null
+++ b/extensions_built_in/advanced_generator/Img2ImgGenerator.py
@@ -0,0 +1,188 @@
+import os
+import random
+from collections import OrderedDict
+from typing import List
+
+import numpy as np
+from PIL import Image
+from diffusers import T2IAdapter
+from torch.utils.data import DataLoader
+from diffusers import StableDiffusionXLImg2ImgPipeline
+from tqdm import tqdm
+
+from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
+from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
+from toolkit.sampler import get_sampler
+from toolkit.stable_diffusion_model import StableDiffusion
+import gc
+import torch
+from jobs.process import BaseExtensionProcess
+from toolkit.data_loader import get_dataloader_from_datasets
+from toolkit.train_tools import get_torch_dtype
+from controlnet_aux.midas import MidasDetector
+from diffusers.utils import load_image
+
+
+def flush():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
+class GenerateConfig:
+
+    def __init__(self, **kwargs):
+        self.prompts: List[str]
+        self.sampler = kwargs.get('sampler', 'ddpm')
+        self.neg = kwargs.get('neg', '')
+        self.seed = kwargs.get('seed', -1)
+        self.walk_seed = kwargs.get('walk_seed', False)
+        self.guidance_scale = kwargs.get('guidance_scale', 7)
+        self.sample_steps = kwargs.get('sample_steps', 20)
+        self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
+        self.ext = kwargs.get('ext', 'png')
+        self.denoise_strength = kwargs.get('denoise_strength', 0.5)
+        self.trigger_word = kwargs.get('trigger_word', None)
+
+
+class Img2ImgGenerator(BaseExtensionProcess):
+
+    def __init__(self, process_id: int, job, config: OrderedDict):
+        super().__init__(process_id, job, config)
+        self.output_folder = self.get_conf('output_folder', required=True)
+        self.copy_inputs_to = self.get_conf('copy_inputs_to', None)
+        self.device = self.get_conf('device', 'cuda')
+        self.model_config = ModelConfig(**self.get_conf('model', required=True))
+        self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
+        self.is_latents_cached = True
+        raw_datasets = self.get_conf('datasets', None)
+        if raw_datasets is not None and len(raw_datasets) > 0:
+            raw_datasets = preprocess_dataset_raw_config(raw_datasets)
+        self.datasets = None
+        self.datasets_reg = None
+        self.dtype = self.get_conf('dtype', 'float16')
+        self.torch_dtype = get_torch_dtype(self.dtype)
+        self.params = []
+        if raw_datasets is not None and len(raw_datasets) > 0:
+            for raw_dataset in raw_datasets:
+                dataset = DatasetConfig(**raw_dataset)
+                is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
+                if not is_caching:
+                    self.is_latents_cached = False
+                if dataset.is_reg:
+                    if self.datasets_reg is None:
+                        self.datasets_reg = []
+                    self.datasets_reg.append(dataset)
+                else:
+                    if self.datasets is None:
+                        self.datasets = []
+                    self.datasets.append(dataset)
+
+        self.progress_bar = None
+        self.sd = StableDiffusion(
+            device=self.device,
+            model_config=self.model_config,
+            dtype=self.dtype,
+        )
+        print(f"Using device {self.device}")
+        self.data_loader: DataLoader = None
+        self.adapter: T2IAdapter = None
+
+    def to_pil(self, img):
+        # image comes in -1 to 1. convert to a PIL RGB image
+        img = (img + 1) / 2
+        img = img.clamp(0, 1)
+        img = img[0].permute(1, 2, 0).cpu().numpy()
+        img = (img * 255).astype(np.uint8)
+        image = Image.fromarray(img)
+        return image
+
+    def run(self):
+        with torch.no_grad():
+            super().run()
+            print("Loading model...")
+            self.sd.load_model()
+            device = torch.device(self.device)
+
+
+            if self.model_config.is_xl:
+                pipe = StableDiffusionXLImg2ImgPipeline(
+                    vae=self.sd.vae,
+                    unet=self.sd.unet,
+                    text_encoder=self.sd.text_encoder[0],
+                    text_encoder_2=self.sd.text_encoder[1],
+                    tokenizer=self.sd.tokenizer[0],
+                    tokenizer_2=self.sd.tokenizer[1],
+                    scheduler=get_sampler(self.generate_config.sampler),
+                ).to(device, dtype=self.torch_dtype)
+            else:
+                raise NotImplementedError("Only XL models are supported")
+            pipe.set_progress_bar_config(disable=True)
+
+            # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+            # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
+
+            self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
+
+            num_batches = len(self.data_loader)
+            pbar = tqdm(total=num_batches, desc="Generating images")
+            seed = self.generate_config.seed
+            # load images from datasets, use tqdm
+            for i, batch in enumerate(self.data_loader):
+                batch: DataLoaderBatchDTO = batch
+
+                file_item: FileItemDTO = batch.file_items[0]
+                img_path = file_item.path
+                img_filename = os.path.basename(img_path)
+                img_filename_no_ext = os.path.splitext(img_filename)[0]
+                img_filename = img_filename_no_ext + '.' + self.generate_config.ext
+                output_path = os.path.join(self.output_folder, img_filename)
+                output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
+
+                if self.copy_inputs_to is not None:
+                    output_inputs_path = os.path.join(self.copy_inputs_to, img_filename)
+                    output_inputs_caption_path = os.path.join(self.copy_inputs_to, img_filename_no_ext + '.txt')
+                else:
+                    output_inputs_path = None
+                    output_inputs_caption_path = None
+
+                caption = batch.get_caption_list()[0]
+                if self.generate_config.trigger_word is not None:
+                    caption = caption.replace('[trigger]', self.generate_config.trigger_word)
+
+                img: torch.Tensor = batch.tensor.clone()
+                image = self.to_pil(img)
+
+
+                # image.save(output_depth_path)
+                pipe: StableDiffusionXLImg2ImgPipeline = pipe
+
+                gen_images = pipe.__call__(
+                    prompt=caption,
+                    negative_prompt=self.generate_config.neg,
+                    image=image,
+                    num_inference_steps=self.generate_config.sample_steps,
+                    guidance_scale=self.generate_config.guidance_scale,
+                    strength=self.generate_config.denoise_strength,
+                ).images[0]
+                os.makedirs(os.path.dirname(output_path), exist_ok=True)
+                gen_images.save(output_path)
+
+                # save caption
+                with open(output_caption_path, 'w') as f:
+                    f.write(caption)
+
+                if output_inputs_path is not None:
+                    os.makedirs(os.path.dirname(output_inputs_path), exist_ok=True)
+                    image.save(output_inputs_path)
+                    with open(output_inputs_caption_path, 'w') as f:
+                        f.write(caption)
+
+                pbar.update(1)
+                batch.cleanup()
+
+            pbar.close()
+            print("Done generating images")
+            # cleanup
+            del self.sd
+            gc.collect()
+            torch.cuda.empty_cache()
diff --git a/extensions_built_in/advanced_generator/__init__.py b/extensions_built_in/advanced_generator/__init__.py
index 94a91c6b..65811655 100644
--- a/extensions_built_in/advanced_generator/__init__.py
+++ b/extensions_built_in/advanced_generator/__init__.py
@@ -36,7 +36,24 @@ class PureLoraGenerator(Extension):
         return PureLoraGenerator
 
 
+# Batch img2img generation over a dataset
+class Img2ImgGeneratorExtension(Extension):
+    # uid must be unique, it is how the extension is identified
+    uid = "batch_img2img"
+
+    # name is the name of the extension for printing
+    name = "Img2ImgGeneratorExtension"
+
+    # This is where your process class is loaded
+    # keep your imports in here so they don't slow down the rest of the program
+    @classmethod
+    def get_process(cls):
+        # import your process class here so it is only loaded when needed and return it
+        from .Img2ImgGenerator import Img2ImgGenerator
+        return Img2ImgGenerator
+
+
 AI_TOOLKIT_EXTENSIONS = [
     # you can put a list of extensions here
-    AdvancedReferenceGeneratorExtension, PureLoraGenerator
+    AdvancedReferenceGeneratorExtension, PureLoraGenerator, Img2ImgGeneratorExtension
 ]
diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index de8aefbc..5e562112 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -483,6 +483,7 @@ class SDTrainer(BaseSDTrainProcess):
                 noise=noise,
                 sd=self.sd,
                 unconditional_embeds=unconditional_embeds,
+                scaler=self.scaler,
                 **kwargs
             )
 
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 846156b2..0cfe9608 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -739,6 +739,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # add to noise
             noise += noise_shift
 
+        # standardize the noise
+        std = noise.std(dim=(2, 3), keepdim=True)
+        normalizer = 1 / (std + 1e-6)
+        noise = noise * normalizer
+
         return noise
 
     def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
@@ -975,14 +980,21 @@
 
             noise = noise * noise_multiplier
 
-            latents = latents * self.train_config.latent_multiplier
+            latent_multiplier = self.train_config.latent_multiplier
+
+            # handle adaptive scaling based on std
+            if self.train_config.adaptive_scaling_factor:
+                std = latents.std(dim=(2, 3), keepdim=True)
+                normalizer = 1 / (std + 1e-6)
+                latent_multiplier = normalizer
+
+            latents = latents * latent_multiplier
+
             batch.latents = latents
 
             # normalize latents to a mean of 0 and an std of 1
             # mean_zero_latents = latents - latents.mean()
             # latents = mean_zero_latents / mean_zero_latents.std()
-
-
             if batch.unconditional_latents is not None:
                 batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier
 
diff --git a/jobs/process/GenerateProcess.py b/jobs/process/GenerateProcess.py
index 4a9ff8c5..1061f65a 100644
--- a/jobs/process/GenerateProcess.py
+++ b/jobs/process/GenerateProcess.py
@@ -80,6 +80,7 @@ class GenerateProcess(BaseProcess):
         self.model_config = ModelConfig(**self.get_conf('model', required=True))
         self.device = self.get_conf('device', self.job.device)
         self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
+        self.torch_dtype = get_torch_dtype(self.get_conf('dtype', 'float16'))
 
         self.progress_bar = None
         self.sd = StableDiffusion(
@@ -87,49 +88,57 @@ class GenerateProcess(BaseProcess):
             model_config=self.model_config,
             dtype=self.model_config.dtype,
         )
+        print(f"Using device {self.device}")
+
+    def clean_prompt(self, prompt: str):
+        # remove any non alpha numeric characters or ,'" from prompt
+        return ''.join(e for e in prompt if e.isalnum() or e in ", '\"")
+
 
     def run(self):
-        super().run()
-        print("Loading model...")
-        self.sd.load_model()
+        with torch.no_grad():
+            super().run()
+            print("Loading model...")
+            self.sd.load_model()
+            self.sd.pipeline.to(self.device, self.torch_dtype)
 
-        print("Compiling model...")
-        # self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
-        if self.generate_config.compile:
-            self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")
+            print("Compiling model...")
+            # self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
+            if self.generate_config.compile:
+                self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")
 
-        print(f"Generating {len(self.generate_config.prompts)} images")
-        # build prompt image configs
-        prompt_image_configs = []
-        for prompt in self.generate_config.prompts:
-            width = self.generate_config.width
-            height = self.generate_config.height
+            print(f"Generating {len(self.generate_config.prompts)} images")
+            # build prompt image configs
+            prompt_image_configs = []
+            for prompt in self.generate_config.prompts:
+                width = self.generate_config.width
+                height = self.generate_config.height
+                prompt = self.clean_prompt(prompt)
 
-            if self.generate_config.size_list is not None:
-                # randomly select a size
-                width, height = random.choice(self.generate_config.size_list)
+                if self.generate_config.size_list is not None:
+                    # randomly select a size
+                    width, height = random.choice(self.generate_config.size_list)
 
-            prompt_image_configs.append(GenerateImageConfig(
-                prompt=prompt,
-                prompt_2=self.generate_config.prompt_2,
-                width=width,
-                height=height,
-                num_inference_steps=self.generate_config.sample_steps,
-                guidance_scale=self.generate_config.guidance_scale,
-                negative_prompt=self.generate_config.neg,
-                negative_prompt_2=self.generate_config.neg_2,
-                seed=self.generate_config.seed,
-                guidance_rescale=self.generate_config.guidance_rescale,
-                output_ext=self.generate_config.ext,
-                output_folder=self.output_folder,
-                add_prompt_file=self.generate_config.prompt_file
-            ))
-        # generate images
-        self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)
+                prompt_image_configs.append(GenerateImageConfig(
+                    prompt=prompt,
+                    prompt_2=self.generate_config.prompt_2,
+                    width=width,
+                    height=height,
+                    num_inference_steps=self.generate_config.sample_steps,
+                    guidance_scale=self.generate_config.guidance_scale,
+                    negative_prompt=self.generate_config.neg,
+                    negative_prompt_2=self.generate_config.neg_2,
+                    seed=self.generate_config.seed,
+                    guidance_rescale=self.generate_config.guidance_rescale,
+                    output_ext=self.generate_config.ext,
+                    output_folder=self.output_folder,
+                    add_prompt_file=self.generate_config.prompt_file
+                ))
+            # generate images
+            self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)
 
-        print("Done generating images")
-        # cleanup
-        del self.sd
-        gc.collect()
-        torch.cuda.empty_cache()
+            print("Done generating images")
+            # cleanup
+            del self.sd
+            gc.collect()
+            torch.cuda.empty_cache()
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index ce416093..ca7ee36e 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -266,6 +266,8 @@ class TrainConfig:
         self.reg_weight = kwargs.get('reg_weight', 1.0)
         self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
         self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
+        # automatically adapt the vae scaling based on the image norm
+        self.adaptive_scaling_factor = kwargs.get('adaptive_scaling_factor', False)
 
         # dropout that happens before encoding. It functions independently per text encoder
         self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index b52dfc53..28d52676 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -387,7 +387,7 @@ class CaptionProcessingDTOMixin:
 
         # join back together
         caption = ', '.join(token_list)
-        caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
+        # caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
 
         if self.dataset_config.random_triggers:
             num_triggers = self.dataset_config.random_triggers_max
diff --git a/toolkit/guidance.py b/toolkit/guidance.py
index 9f6bc1e1..540de72c 100644
--- a/toolkit/guidance.py
+++ b/toolkit/guidance.py
@@ -407,6 +407,7 @@ def get_guided_loss_polarity(
         batch: 'DataLoaderBatchDTO',
         noise: torch.Tensor,
         sd: 'StableDiffusion',
+        scaler=None,
         **kwargs
 ):
     dtype = get_torch_dtype(sd.torch_dtype)
@@ -473,7 +474,10 @@
     loss = loss.mean([1, 2, 3])
     loss = loss.mean()
 
-    loss.backward()
+    if scaler is not None:
+        scaler.scale(loss).backward()
+    else:
+        loss.backward()
 
     # detach it so parent class can run backward on no grads without throwing error
     loss = loss.detach()
@@ -590,6 +594,7 @@ def get_guidance_loss(
         unconditional_embeds: Optional[PromptEmbeds] = None,
         mask_multiplier=None,
         prior_pred=None,
+        scaler=None,
         **kwargs
 ):
     # TODO add others and process individual batch items separately
@@ -621,6 +626,7 @@
             batch,
             noise,
             sd,
+            scaler=scaler,
             **kwargs
         )
     elif guidance_type == "tnt":
diff --git a/toolkit/sampler.py b/toolkit/sampler.py
index 6d42d94f..f9b0311b 100644
--- a/toolkit/sampler.py
+++ b/toolkit/sampler.py
@@ -41,9 +41,12 @@ sd_config = {
     "prediction_type": "epsilon",
     "sample_max_value": 1.0,
     "set_alpha_to_one": False,
-    "skip_prk_steps": False,
-    "steps_offset": 1,
-    "timestep_spacing": "trailing",
+    # "skip_prk_steps": False, # for training
+    "skip_prk_steps": True,
+    # "steps_offset": 1,
+    "steps_offset": 0,
+    # "timestep_spacing": "trailing", # for training
+    "timestep_spacing": "leading",
     "trained_betas": None
 }