From e8583860adbe236609e91a238ab4a4191c5d2ca8 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Mon, 11 Sep 2023 14:46:06 -0600 Subject: [PATCH] Upgraded to dev for t2i on diffusers. Minor migrations to make it work. --- .../advanced_generator/ReferenceGenerator.py | 193 ++++++++++++++++++ .../advanced_generator/__init__.py | 25 +++ .../config/train.example.yaml | 91 +++++++++ requirements.txt | 5 +- toolkit/config_modules.py | 2 +- toolkit/data_loader.py | 45 +++- toolkit/network_mixins.py | 7 +- 7 files changed, 356 insertions(+), 12 deletions(-) create mode 100644 extensions_built_in/advanced_generator/ReferenceGenerator.py create mode 100644 extensions_built_in/advanced_generator/__init__.py create mode 100644 extensions_built_in/advanced_generator/config/train.example.yaml diff --git a/extensions_built_in/advanced_generator/ReferenceGenerator.py b/extensions_built_in/advanced_generator/ReferenceGenerator.py new file mode 100644 index 00000000..14ae1ff5 --- /dev/null +++ b/extensions_built_in/advanced_generator/ReferenceGenerator.py @@ -0,0 +1,193 @@ +import os +import random +from collections import OrderedDict +from typing import List + +import numpy as np +from PIL import Image +from diffusers import T2IAdapter +from torch.utils.data import DataLoader +from diffusers import StableDiffusionXLAdapterPipeline +from tqdm import tqdm + +from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig +from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO +from toolkit.sampler import get_sampler +from toolkit.stable_diffusion_model import StableDiffusion +import gc +import torch +from jobs.process import BaseExtensionProcess +from toolkit.data_loader import get_dataloader_from_datasets +from toolkit.train_tools import get_torch_dtype +from controlnet_aux.midas import MidasDetector +from diffusers.utils import load_image + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class GenerateConfig: + + def __init__(self, **kwargs): + self.prompts: List[str] + self.sampler = kwargs.get('sampler', 'ddpm') + self.neg = kwargs.get('neg', '') + self.seed = kwargs.get('seed', -1) + self.walk_seed = kwargs.get('walk_seed', False) + self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None) + self.guidance_scale = kwargs.get('guidance_scale', 7) + self.sample_steps = kwargs.get('sample_steps', 20) + self.prompt_2 = kwargs.get('prompt_2', None) + self.neg_2 = kwargs.get('neg_2', None) + self.prompts = kwargs.get('prompts', None) + self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) + self.ext = kwargs.get('ext', 'png') + self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0) + if kwargs.get('shuffle', False): + # shuffle the prompts + random.shuffle(self.prompts) + + +class ReferenceGenerator(BaseExtensionProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.output_folder = self.get_conf('output_folder', required=True) + self.device = self.get_conf('device', 'cuda') + self.model_config = ModelConfig(**self.get_conf('model', required=True)) + self.generate_config = GenerateConfig(**self.get_conf('generate', required=True)) + self.is_latents_cached = True + raw_datasets = self.get_conf('datasets', None) + if raw_datasets is not None and len(raw_datasets) > 0: + raw_datasets = preprocess_dataset_raw_config(raw_datasets) + self.datasets = None + self.datasets_reg = None + self.dtype = self.get_conf('dtype', 
'float16')
+        self.torch_dtype = get_torch_dtype(self.dtype)
+        self.params = []
+        if raw_datasets is not None and len(raw_datasets) > 0:
+            for raw_dataset in raw_datasets:
+                dataset = DatasetConfig(**raw_dataset)
+                is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
+                if not is_caching:
+                    self.is_latents_cached = False
+                if dataset.is_reg:
+                    if self.datasets_reg is None:
+                        self.datasets_reg = []
+                    self.datasets_reg.append(dataset)
+                else:
+                    if self.datasets is None:
+                        self.datasets = []
+                    self.datasets.append(dataset)
+
+        self.progress_bar = None
+        self.sd = StableDiffusion(
+            device=self.device,
+            model_config=self.model_config,
+            dtype=self.dtype,
+        )
+        print(f"Using device {self.device}")
+        self.data_loader: DataLoader = None
+        self.adapter: T2IAdapter = None
+
+    def run(self):
+        super().run()
+        print("Loading model...")
+        self.sd.load_model()
+        device = torch.device(self.device)
+
+        if self.generate_config.t2i_adapter_path is not None:
+            self.adapter = T2IAdapter.from_pretrained(
+                "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=self.torch_dtype, variant="fp16"
+            ).to(device)
+
+        midas_depth = MidasDetector.from_pretrained(
+            "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
+        ).to(device)
+
+        pipe = StableDiffusionXLAdapterPipeline(
+            vae=self.sd.vae,
+            unet=self.sd.unet,
+            text_encoder=self.sd.text_encoder[0],
+            text_encoder_2=self.sd.text_encoder[1],
+            tokenizer=self.sd.tokenizer[0],
+            tokenizer_2=self.sd.tokenizer[1],
+            scheduler=get_sampler(self.generate_config.sampler),
+            adapter=self.adapter,
+        ).to(device)
+        pipe.set_progress_bar_config(disable=True)
+
+        self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
+
+        num_batches = len(self.data_loader)
+        pbar = tqdm(total=num_batches, desc="Generating images")
+        seed = self.generate_config.seed
+        # load images from datasets, use tqdm
+        for i, batch in enumerate(self.data_loader):
+            batch: DataLoaderBatchDTO = batch
+
+            file_item: FileItemDTO = batch.file_items[0]
+            img_path = file_item.path
+            img_filename = os.path.basename(img_path)
+            img_filename_no_ext = os.path.splitext(img_filename)[0]
+            output_path = os.path.join(self.output_folder, img_filename)
+            output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
+            output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png')
+
+            caption = batch.get_caption_list()[0]
+
+            img: torch.Tensor = batch.tensor.clone()
+            # image comes in -1 to 1. convert to a PIL RGB image
+            img = (img + 1) / 2
+            img = img.clamp(0, 1)
+            img = img[0].permute(1, 2, 0).cpu().numpy()
+            img = (img * 255).astype(np.uint8)
+            image = Image.fromarray(img)
+
+            width, height = image.size
+            min_res = min(width, height)
+
+            if self.generate_config.walk_seed:
+                seed = seed + 1
+
+            if self.generate_config.seed == -1:
+                # random
+                seed = random.randint(0, 1000000)
+
+            torch.manual_seed(seed)
+            torch.cuda.manual_seed(seed)
+
+            # generate depth map
+            image = midas_depth(
+                image,
+                detect_resolution=min_res,  # do 512 ?
+ image_resolution=min_res + ) + + # image.save(output_depth_path) + + gen_images = pipe( + prompt=caption, + negative_prompt=self.generate_config.neg, + image=image, + num_inference_steps=self.generate_config.sample_steps, + adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale, + guidance_scale=self.generate_config.guidance_scale, + ).images[0] + gen_images.save(output_path) + + # save caption + with open(output_caption_path, 'w') as f: + f.write(caption) + + pbar.update(1) + batch.cleanup() + + pbar.close() + print("Done generating images") + # cleanup + del self.sd + gc.collect() + torch.cuda.empty_cache() diff --git a/extensions_built_in/advanced_generator/__init__.py b/extensions_built_in/advanced_generator/__init__.py new file mode 100644 index 00000000..d811fe89 --- /dev/null +++ b/extensions_built_in/advanced_generator/__init__.py @@ -0,0 +1,25 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. +from toolkit.extension import Extension + + +# This is for generic training (LoRA, Dreambooth, FineTuning) +class AdvancedReferenceGeneratorExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "reference_generator" + + # name is the name of the extension for printing + name = "Reference Generator" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .ReferenceGenerator import ReferenceGenerator + return ReferenceGenerator + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + AdvancedReferenceGeneratorExtension, +] diff --git a/extensions_built_in/advanced_generator/config/train.example.yaml b/extensions_built_in/advanced_generator/config/train.example.yaml new file mode 100644 index 00000000..793d5d55 --- /dev/null +++ b/extensions_built_in/advanced_generator/config/train.example.yaml @@ -0,0 +1,91 @@ +--- +job: extension +config: + name: test_v1 + process: + - type: 'textual_inversion_trainer' + training_folder: "out/TI" + device: cuda:0 + # for tensorboard logging + log_dir: "out/.tensorboard" + embedding: + trigger: "your_trigger_here" + tokens: 12 + init_words: "man with short brown hair" + save_format: "safetensors" # 'safetensors' or 'pt' + save: + dtype: float16 # precision to save + save_every: 100 # save every this many steps + max_step_saves_to_keep: 5 # only affects step counts + datasets: + - folder_path: "/path/to/dataset" + caption_ext: "txt" + default_caption: "[trigger]" + buckets: true + resolution: 512 + train: + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + steps: 3000 + weight_jitter: 0.0 + lr: 5e-5 + train_unet: false + gradient_checkpointing: true + train_text_encoder: false + optimizer: "adamw" +# optimizer: "prodigy" + optimizer_params: + weight_decay: 1e-2 + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 4 + dtype: bf16 + xformers: true + min_snr_gamma: 5.0 +# skip_first_sample: true + noise_offset: 0.0 # not needed for this + model: + # objective reality v2 + name_or_path: "https://civitai.com/models/128453?modelVersionId=142465" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 models) + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - 
"photo of [trigger] laughing" + - "photo of [trigger] smiling" + - "[trigger] close up" + - "dark scene [trigger] frozen" + - "[trigger] nighttime" + - "a painting of [trigger]" + - "a drawing of [trigger]" + - "a cartoon of [trigger]" + - "[trigger] pixar style" + - "[trigger] costume" + neg: "" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false + +# You can put any information you want here, and it will be saved in the model. +# The below is an example, but you can put your grocery list in it if you want. +# It is saved in the model so be aware of that. The software will include this +# plus some other information for you automatically +meta: + # [name] gets replaced with the name above + name: "[name]" +# version: '1.0' +# creator: +# name: Your Name +# email: your@gmail.com +# website: https://your.website diff --git a/requirements.txt b/requirements.txt index 8c848127..45a304ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torch torchvision safetensors -diffusers +git+https://github.com/huggingface/diffusers.git transformers lycoris_lora flatten_json @@ -19,4 +19,5 @@ omegaconf k-diffusion open_clip_torch timm -prodigyopt \ No newline at end of file +prodigyopt +controlnet_aux==0.0.7 \ No newline at end of file diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 50faac60..bb1b321e 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -77,7 +77,7 @@ class TrainConfig: self.optimizer_params = kwargs.get('optimizer_params', {}) self.lr_scheduler = kwargs.get('lr_scheduler', 'constant') self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {}) - self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 50) + self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000) self.batch_size: int = kwargs.get('batch_size', 1) self.dtype: str = kwargs.get('dtype', 'fp32') self.xformers = kwargs.get('xformers', False) diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 682caceb..33a9161c 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -12,6 +12,7 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset from tqdm import tqdm import albumentations as A +from toolkit.buckets import get_bucket_for_image_size, BucketResolution from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO @@ -268,6 +269,37 @@ class PairedImageDataset(Dataset): img1 = exif_transpose(Image.open(img_path)).convert('RGB') img_path = img_path_or_tuple[1] img2 = exif_transpose(Image.open(img_path)).convert('RGB') + + # always use # 2 (pos) + bucket_resolution = get_bucket_for_image_size( + width=img2.width, + height=img2.height, + resolution=self.size + ) + + # images will be same base dimension, but may be trimmed. 
We need to shrink and then center crop
+            if bucket_resolution['width'] > bucket_resolution['height']:
+                img1_scale_to_height = bucket_resolution["height"]
+                img1_scale_to_width = int(img1.width * (bucket_resolution["height"] / img1.height))
+                img2_scale_to_height = bucket_resolution["height"]
+                img2_scale_to_width = int(img2.width * (bucket_resolution["height"] / img2.height))
+            else:
+                img1_scale_to_width = bucket_resolution["width"]
+                img1_scale_to_height = int(img1.height * (bucket_resolution["width"] / img1.width))
+                img2_scale_to_width = bucket_resolution["width"]
+                img2_scale_to_height = int(img2.height * (bucket_resolution["width"] / img2.width))
+
+            img1_crop_height = bucket_resolution["height"]
+            img1_crop_width = bucket_resolution["width"]
+            img2_crop_height = bucket_resolution["height"]
+            img2_crop_width = bucket_resolution["width"]
+
+            # scale then center crop images
+            img1 = img1.resize((img1_scale_to_width, img1_scale_to_height), Image.BICUBIC)
+            img1 = transforms.CenterCrop((img1_crop_height, img1_crop_width))(img1)
+            img2 = img2.resize((img2_scale_to_width, img2_scale_to_height), Image.BICUBIC)
+            img2 = transforms.CenterCrop((img2_crop_height, img2_crop_width))(img2)
+
             # combine them side by side
             img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
             img.paste(img1, (0, 0))
@@ -275,15 +307,14 @@ class PairedImageDataset(Dataset):
         else:
             img_path = img_path_or_tuple
             img = exif_transpose(Image.open(img_path)).convert('RGB')
+            height = self.size
+            # determine width to keep aspect ratio
+            width = int(img.size[0] * height / img.size[1])
+
+            # Downscale the source image first
+            img = img.resize((width, height), Image.BICUBIC)
 
         prompt = self.get_prompt_item(index)
-
-        height = self.size
-        # determine width to keep aspect ratio
-        width = int(img.size[0] * height / img.size[1])
-
-        # Downscale the source image first
-        img = img.resize((width, height), Image.BICUBIC)
         img = self.transform(img)
 
         return img, prompt, (self.neg_weight, self.pos_weight)
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 7955642c..08b316f5 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -122,11 +122,14 @@ class ToolkitModuleMixin:
 
         return lx * scale
 
-    def forward(self: Module, x):
+    # newer diffusers may pass extra positional args (e.g. a resnet scale), so accept and forward them
+
+    def forward(self: Module, x, *args, **kwargs):
+        # diffusers dev passes a scale arg to some resnet forwards; pass any extra args straight through
         if self._multiplier is None:
             self.set_multiplier(0.0)
 
-        org_forwarded = self.org_forward(x)
+        org_forwarded = self.org_forward(x, *args, **kwargs)
         lora_output = self._call_forward(x)
         multiplier = self._multiplier.clone().detach()
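For reference, below is a minimal standalone sketch of the depth-conditioned T2I-Adapter flow that the new ReferenceGenerator drives. The adapter and MiDaS checkpoint ids match the ones used in the patch; the SDXL base checkpoint, input image path, prompt, and output path are placeholder assumptions, and ReferenceGenerator itself assembles the pipeline from an already-loaded StableDiffusion model rather than calling from_pretrained.

import torch
from controlnet_aux.midas import MidasDetector
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
from diffusers.utils import load_image

device = "cuda"
dtype = torch.float16

# depth T2I-Adapter and MiDaS depth estimator (same ids as in ReferenceGenerator.run)
adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=dtype, variant="fp16"
).to(device)
midas_depth = MidasDetector.from_pretrained(
    "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
).to(device)

# placeholder SDXL base; the extension instead reuses the components of its loaded model
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", adapter=adapter, torch_dtype=dtype, variant="fp16"
).to(device)

image = load_image("reference.png")  # placeholder input image
min_res = min(image.size)

# the depth map conditions the adapter, mirroring the per-batch loop above
depth = midas_depth(image, detect_resolution=min_res, image_resolution=min_res)

result = pipe(
    prompt="photo of a person",  # placeholder prompt
    negative_prompt="",
    image=depth,
    num_inference_steps=20,
    adapter_conditioning_scale=1.0,
    guidance_scale=7.0,
).images[0]
result.save("output.png")  # placeholder output path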