From 66c6f0f6f77fd1a2bb226dff423359a24fdcd8c6 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 3 Aug 2023 14:51:25 -0600
Subject: [PATCH] Big refactor of SD runner and added image generator

---
 README.md | 15 +
 config/examples/generate.example.yaml | 60 +++
 config/examples/train_slider.example.yml | 3 +-
 jobs/GenerateJob.py | 32 ++
 jobs/__init__.py | 1 +
 jobs/process/BaseProcess.py | 3 +-
 jobs/process/BaseSDTrainProcess.py | 443 +++--------------------
 jobs/process/GenerateProcess.py | 102 ++++++
 jobs/process/TrainSDRescaleProcess.py | 8 +-
 jobs/process/TrainSliderProcess.py | 23 +-
 jobs/process/TrainSliderProcessOld.py | 8 +-
 jobs/process/__init__.py | 1 +
 toolkit/config_modules.py | 203 ++++++++++-
 toolkit/job.py | 3 +
 toolkit/scheduler.py | 33 ++
 toolkit/stable_diffusion_model.py | 415 ++++++++++++++++++++-
 16 files changed, 923 insertions(+), 430 deletions(-)
 create mode 100644 config/examples/generate.example.yaml
 create mode 100644 jobs/GenerateJob.py
 create mode 100644 jobs/process/GenerateProcess.py
 create mode 100644 toolkit/scheduler.py

diff --git a/README.md b/README.md
index 25e6db17..50f7ee8f 100644
--- a/README.md
+++ b/README.md
@@ -42,6 +42,16 @@ here so far.
 
 ---
 
+### Batch Image Generation
+
+An image generator that can take prompts from a config file or from a txt file and generate them to a
+folder. I mainly needed this for an SDXL test I am doing, but added some polish to it so it can be used
+for general batch image generation.
+It all runs off a config file, which you can find an example of in `config/examples/generate.example.yaml`.
+More info is in the comments in the example.
+
+---
+
 ### LoRA (lierla), LoCON (LyCORIS) extractor
 
 It is based on the extractor in the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) tool, but adding some QOL features
@@ -143,6 +153,11 @@ Just went in and out. It is much worse on smaller faces than shown here.
 
 ## Change Log
 
+#### 2023-08-03
+Another big refactor to make SD more modular.
+
+Added a batch image generation script.
+
 #### 2023-08-01
 Major changes and update. New LoRA rescale tool, look above for details. Added better metadata so
 Automatic1111 knows what the base model is. Added some experiments and a ton of updates. This thing is still unstable
diff --git a/config/examples/generate.example.yaml b/config/examples/generate.example.yaml
new file mode 100644
index 00000000..1a3e19ef
--- /dev/null
+++ b/config/examples/generate.example.yaml
@@ -0,0 +1,60 @@
+---
+
+job: generate # tells the runner what to do
+config:
+  name: "generate" # this is not really used anywhere currently, but it is required by the runner
+  process:
+    # process 1
+    - type: to_folder # process images to a folder
+      output_folder: "output/gen"
+      device: cuda:0 # cpu, cuda:0, etc
+      generate:
+        # these are your defaults; you can override most of them with flags
+        sampler: "ddpm" # ignored for now; ddpm is always used. Other samplers will be added later
+        width: 1024
+        height: 1024
+        neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
+        seed: -1 # -1 is random
+        guidance_scale: 7
+        sample_steps: 20
+        ext: ".png" # .png, .jpg, .jpeg, .webp
+
+        # here are the flags you can use for prompts. Always start with
+        # your prompt first, then add these flags after. You can use as many
+        # as you like, for example
+        # photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20
+        # we will try to support all sd-scripts flags where we can
+
+        # FROM SD-SCRIPTS
+        # --n Treat everything until the next option as a negative prompt.
+ # --w Specify the width of the generated image. + # --h Specify the height of the generated image. + # --d Specify the seed for the generated image. + # --l Specify the CFG scale for the generated image. + # --s Specify the number of steps during generation. + + # OURS and some QOL additions + # --p2 Prompt for the second text encoder (SDXL only) + # --n2 Negative prompt for the second text encoder (SDXL only) + # --gr Specify the guidance rescale for the generated image (SDXL only) + # --seed Specify the seed for the generated image same as --d + # --cfg Specify the CFG scale for the generated image same as --l + # --steps Specify the number of steps during generation same as --s + + prompt_file: false # if true a txt file will be created next to images with prompt strings used + # prompts can also be a path to a text file with one prompt per line + # prompts: "/path/to/prompts.txt" + prompts: + - "photo of batman" + - "photo of superman" + - "photo of spiderman" + - "photo of a superhero --n batman superman spiderman" + + model: + # huggingface name, relative prom project path, or absolute path to .safetensors or .ckpt + # name_or_path: "runwayml/stable-diffusion-v1-5" + name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors" + is_v2: false # for v2 models + is_v_pred: false # for v-prediction models (most v2 models) + is_xl: false # for SDXL models + dtype: bf16 diff --git a/config/examples/train_slider.example.yml b/config/examples/train_slider.example.yml index 84b9eca6..be796a09 100644 --- a/config/examples/train_slider.example.yml +++ b/config/examples/train_slider.example.yml @@ -57,7 +57,8 @@ config: # bf16 works best if your GPU supports it (modern) dtype: bf16 # fp32, bf16, fp16 # if you have it, use it. It is faster and better - xformers: true + # torch 2.0 doesnt need xformers anymore, only use if you have lower version +# xformers: true # I don't recommend using unless you are trying to make a darker lora. 
Then do 0.1 MAX # although, the way we train sliders is comparative, so it probably won't work anyway noise_offset: 0.0 diff --git a/jobs/GenerateJob.py b/jobs/GenerateJob.py new file mode 100644 index 00000000..5bab1142 --- /dev/null +++ b/jobs/GenerateJob.py @@ -0,0 +1,32 @@ +from jobs import BaseJob +from collections import OrderedDict +from typing import List +from jobs.process import GenerateProcess +from toolkit.paths import REPOS_ROOT + +import sys + +sys.path.append(REPOS_ROOT) + +process_dict = { + 'to_folder': 'GenerateProcess', +} + + +class GenerateJob(BaseJob): + process: List[GenerateProcess] + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.device = self.get_conf('device', 'cpu') + + # loads the processes from the config + self.load_processes(process_dict) + + def run(self): + super().run() + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() diff --git a/jobs/__init__.py b/jobs/__init__.py index f00a20dc..9f232a52 100644 --- a/jobs/__init__.py +++ b/jobs/__init__.py @@ -3,3 +3,4 @@ from .ExtractJob import ExtractJob from .TrainJob import TrainJob from .MergeJob import MergeJob from .ModJob import ModJob +from .GenerateJob import GenerateJob diff --git a/jobs/process/BaseProcess.py b/jobs/process/BaseProcess.py index 821eebaa..167cdc14 100644 --- a/jobs/process/BaseProcess.py +++ b/jobs/process/BaseProcess.py @@ -1,10 +1,9 @@ import copy import json from collections import OrderedDict -from typing import ForwardRef -class BaseProcess: +class BaseProcess(object): meta: OrderedDict def __init__( diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 7b0b14a1..afde7db7 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1,34 +1,23 @@ import glob -import time from collections import OrderedDict import os -from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg - from toolkit.lora_special import LoRASpecialNetwork from toolkit.optimizer import get_optimizer -from toolkit.paths import REPOS_ROOT -import sys -from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline - -sys.path.append(REPOS_ROOT) -sys.path.append(os.path.join(REPOS_ROOT, 'leco')) - -from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2DiscreteScheduler, PNDMScheduler, \ - DDIMScheduler, DDPMScheduler +from toolkit.scheduler import get_lr_scheduler +from toolkit.stable_diffusion_model import StableDiffusion from jobs.process import BaseTrainProcess from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta -from toolkit.train_tools import get_torch_dtype, apply_noise_offset +from toolkit.train_tools import get_torch_dtype import gc import torch from tqdm import tqdm -from leco import train_util, model_util -from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig -from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds +from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \ + GenerateImageConfig def flush(): @@ -36,11 +25,9 @@ def flush(): gc.collect() -UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。 -VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 - - class 
BaseSDTrainProcess(BaseTrainProcess): + sd: StableDiffusion + def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None): super().__init__(process_id, job, config) self.custom_pipeline = custom_pipeline @@ -64,177 +51,52 @@ class BaseSDTrainProcess(BaseTrainProcess): self.logging_config = LogingConfig(**self.get_conf('logging', {})) self.optimizer = None self.lr_scheduler = None - self.sd: 'StableDiffusion' = None - # sdxl stuff - self.logit_scale = None - self.ckppt_info = None + self.sd = StableDiffusion( + device=self.device, + model_config=self.model_config, + dtype=self.train_config.dtype, + custom_pipeline=self.custom_pipeline, + ) - # added later + # to hold network if there is one self.network = None def sample(self, step=None, is_first=False): sample_folder = os.path.join(self.save_root, 'samples') - if not os.path.exists(sample_folder): - os.makedirs(sample_folder, exist_ok=True) - - if self.network is not None: - self.network.eval() - - # save current seed state for training - rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None - - original_device_dict = { - 'vae': self.sd.vae.device, - 'unet': self.sd.unet.device, - # 'tokenizer': self.sd.tokenizer.device, - } - - # handle sdxl text encoder - if isinstance(self.sd.text_encoder, list): - for encoder, i in zip(self.sd.text_encoder, range(len(self.sd.text_encoder))): - original_device_dict[f'text_encoder_{i}'] = encoder.device - encoder.to(self.device_torch) - else: - original_device_dict['text_encoder'] = self.sd.text_encoder.device - self.sd.text_encoder.to(self.device_torch) - - self.sd.vae.to(self.device_torch) - self.sd.unet.to(self.device_torch) - # self.sd.text_encoder.to(self.device_torch) - # self.sd.tokenizer.to(self.device_torch) - # TODO add clip skip - if self.sd.is_xl: - pipeline = StableDiffusionXLPipeline( - vae=self.sd.vae, - unet=self.sd.unet, - text_encoder=self.sd.text_encoder[0], - text_encoder_2=self.sd.text_encoder[1], - tokenizer=self.sd.tokenizer[0], - tokenizer_2=self.sd.tokenizer[1], - scheduler=self.sd.noise_scheduler, - ).to(self.device_torch) - else: - pipeline = StableDiffusionPipeline( - vae=self.sd.vae, - unet=self.sd.unet, - text_encoder=self.sd.text_encoder, - tokenizer=self.sd.tokenizer, - scheduler=self.sd.noise_scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ).to(self.device_torch) - # disable progress bar - pipeline.set_progress_bar_config(disable=True) + gen_img_config_list = [] sample_config = self.first_sample_config if is_first else self.sample_config - start_seed = sample_config.seed - start_multiplier = self.network.multiplier current_seed = start_seed + for i in range(len(sample_config.prompts)): + if sample_config.walk_seed: + current_seed = start_seed + i - pipeline.to(self.device_torch) - with self.network: - with torch.no_grad(): - if self.network is not None: - assert self.network.is_active - if self.logging_config.verbose: - print("network_state", { - 'is_active': self.network.is_active, - 'multiplier': self.network.multiplier, - }) + step_num = '' + if step is not None: + # zero-pad 9 digits + step_num = f"_{str(step).zfill(9)}" - for i in tqdm(range(len(sample_config.prompts)), desc=f"Generating Samples - step: {step}", - leave=False): - raw_prompt = sample_config.prompts[i] + filename = f"[time]_{step_num}_[count].png" - neg = sample_config.neg - multiplier = sample_config.network_multiplier - p_split = raw_prompt.split('--') - prompt = 
p_split[0].strip() - height = sample_config.height - width = sample_config.width + output_path = os.path.join(sample_folder, filename) - if len(p_split) > 1: - for split in p_split: - flag = split[:1] - content = split[1:].strip() - if flag == 'n': - neg = content - elif flag == 'm': - # multiplier - multiplier = float(content) - elif flag == 'w': - # multiplier - width = int(content) - elif flag == 'h': - # multiplier - height = int(content) + gen_img_config_list.append(GenerateImageConfig( + prompt=sample_config.prompts[i], # it will autoparse the prompt + width=sample_config.width, + height=sample_config.height, + negative_prompt=sample_config.neg, + seed=current_seed, + guidance_scale=sample_config.guidance_scale, + guidance_rescale=sample_config.guidance_rescale, + num_inference_steps=sample_config.sample_steps, + network_multiplier=sample_config.network_multiplier, + output_path=output_path, + )) - height = max(64, height - height % 8) # round to divisible by 8 - width = max(64, width - width % 8) # round to divisible by 8 - - if sample_config.walk_seed: - current_seed += i - - if self.network is not None: - self.network.multiplier = multiplier - torch.manual_seed(current_seed) - torch.cuda.manual_seed(current_seed) - - if self.sd.is_xl: - img = pipeline( - prompt, - height=height, - width=width, - num_inference_steps=sample_config.sample_steps, - guidance_scale=sample_config.guidance_scale, - negative_prompt=neg, - guidance_rescale=0.7, - ).images[0] - else: - img = pipeline( - prompt, - height=height, - width=width, - num_inference_steps=sample_config.sample_steps, - guidance_scale=sample_config.guidance_scale, - negative_prompt=neg, - ).images[0] - - step_num = '' - if step is not None: - # zero-pad 9 digits - step_num = f"_{str(step).zfill(9)}" - seconds_since_epoch = int(time.time()) - # zero-pad 2 digits - i_str = str(i).zfill(2) - filename = f"{seconds_since_epoch}{step_num}_{i_str}.png" - output_path = os.path.join(sample_folder, filename) - img.save(output_path) - - # clear pipeline and cache to reduce vram usage - del pipeline - torch.cuda.empty_cache() - - # restore training state - torch.set_rng_state(rng_state) - if cuda_rng_state is not None: - torch.cuda.set_rng_state(cuda_rng_state) - - self.sd.vae.to(original_device_dict['vae']) - self.sd.unet.to(original_device_dict['unet']) - if isinstance(self.sd.text_encoder, list): - for encoder, i in zip(self.sd.text_encoder, range(len(self.sd.text_encoder))): - encoder.to(original_device_dict[f'text_encoder_{i}']) - else: - self.sd.text_encoder.to(original_device_dict['text_encoder']) - if self.network is not None: - self.network.train() - self.network.multiplier = start_multiplier - # self.sd.tokenizer.to(original_device_dict['tokenizer']) + # send to be generated + self.sd.generate_images(gen_img_config_list) def update_training_metadata(self): o_dict = OrderedDict({ @@ -328,148 +190,10 @@ class BaseSDTrainProcess(BaseTrainProcess): def hook_before_train_loop(self): pass - def get_latent_noise( - self, - height=None, - width=None, - pixel_height=None, - pixel_width=None, - ): - if height is None and pixel_height is None: - raise ValueError("height or pixel_height must be specified") - if width is None and pixel_width is None: - raise ValueError("width or pixel_width must be specified") - if height is None: - height = pixel_height // VAE_SCALE_FACTOR - if width is None: - width = pixel_width // VAE_SCALE_FACTOR - - noise = torch.randn( - ( - self.train_config.batch_size, - UNET_IN_CHANNELS, - height, - width, - ), - 
device="cpu", - ) - noise = apply_noise_offset(noise, self.train_config.noise_offset) - return noise - def hook_train_loop(self): # return loss return 0.0 - def get_time_ids_from_latents(self, latents): - bs, ch, h, w = list(latents.shape) - - height = h * VAE_SCALE_FACTOR - width = w * VAE_SCALE_FACTOR - - dtype = get_torch_dtype(self.train_config.dtype) - - if self.sd.is_xl: - prompt_ids = train_util.get_add_time_ids( - height, - width, - dynamic_crops=False, # look into this - dtype=dtype, - ).to(self.device_torch, dtype=dtype) - return train_util.concat_embeddings( - prompt_ids, prompt_ids, bs - ) - else: - return None - - def predict_noise( - self, - latents: torch.FloatTensor, - text_embeddings: PromptEmbeds, - timestep: int, - guidance_scale=7.5, - guidance_rescale=0, # 0.7 - add_time_ids=None, - **kwargs, - ): - - if self.sd.is_xl: - if add_time_ids is None: - add_time_ids = self.get_time_ids_from_latents(latents) - - latent_model_input = torch.cat([latents] * 2) - - latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep) - - added_cond_kwargs = { - "text_embeds": text_embeddings.pooled_embeds, - "time_ids": add_time_ids, - } - - # predict the noise residual - noise_pred = self.sd.unet( - latent_model_input, - timestep, - encoder_hidden_states=text_embeddings.text_embeds, - added_cond_kwargs=added_cond_kwargs, - ).sample - - # perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - - # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 - if guidance_rescale > 0.0: - # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - - else: - # if we are doing classifier free guidance, need to double up - latent_model_input = torch.cat([latents] * 2) - - latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep) - - # predict the noise residual - noise_pred = self.sd.unet( - latent_model_input, - timestep, - encoder_hidden_states=text_embeddings.text_embeds, - ).sample - - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - - return noise_pred - - # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 - def diffuse_some_steps( - self, - latents: torch.FloatTensor, - text_embeddings: PromptEmbeds, - total_timesteps: int = 1000, - start_timesteps=0, - guidance_scale=1, - add_time_ids=None, - **kwargs, - ): - - for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False): - noise_pred = self.predict_noise( - latents, - text_embeddings, - timestep, - guidance_scale=guidance_scale, - add_time_ids=add_time_ids, - **kwargs, - ) - latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample - - # return latents_steps - return latents - def get_latest_save_path(self): # get latest saved step if os.path.exists(self.save_root): @@ -497,92 +221,33 @@ class BaseSDTrainProcess(BaseTrainProcess): print("load_weights not implemented for non-network models") def run(self): - super().run() - + # run base process run + BaseTrainProcess.run(self) ### HOOK ### self.hook_before_model_load() + # run base sd process run + self.sd.load_model() dtype = get_torch_dtype(self.train_config.dtype) - # TODO handle other schedulers - # sch = KDPM2DiscreteScheduler - sch = DDPMScheduler - # do our own scheduler - prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon" - scheduler = sch( - num_train_timesteps=1000, - beta_start=0.00085, - beta_end=0.0120, - beta_schedule="scaled_linear", - clip_sample=False, - prediction_type=prediction_type, - ) - if self.model_config.is_xl: - if self.custom_pipeline is not None: - pipln = self.custom_pipeline - else: - pipln = CustomStableDiffusionXLPipeline - pipe = pipln.from_single_file( - self.model_config.name_or_path, - dtype=dtype, - scheduler_type='ddpm', - device=self.device_torch, - ).to(self.device_torch) + # model is loaded from BaseSDProcess + unet = self.sd.unet + vae = self.sd.vae + tokenizer = self.sd.tokenizer + text_encoder = self.sd.text_encoder + noise_scheduler = self.sd.noise_scheduler - text_encoders = [pipe.text_encoder, pipe.text_encoder_2] - tokenizer = [pipe.tokenizer, pipe.tokenizer_2] - for text_encoder in text_encoders: - text_encoder.to(self.device_torch, dtype=dtype) - text_encoder.requires_grad_(False) - text_encoder.eval() - text_encoder = text_encoders - else: - if self.custom_pipeline is not None: - pipln = self.custom_pipeline - else: - pipln = CustomStableDiffusionPipeline - pipe = pipln.from_single_file( - self.model_config.name_or_path, - dtype=dtype, - scheduler_type='dpm', - device=self.device_torch, - load_safety_checker=False, - ).to(self.device_torch) - pipe.register_to_config(requires_safety_checker=False) - text_encoder = pipe.text_encoder - text_encoder.to(self.device_torch, dtype=dtype) - 
text_encoder.requires_grad_(False) - text_encoder.eval() - tokenizer = pipe.tokenizer - - # scheduler doesn't get set sometimes, so we set it here - pipe.scheduler = scheduler - - unet = pipe.unet - noise_scheduler = pipe.scheduler - vae = pipe.vae.to('cpu', dtype=dtype) - vae.eval() - vae.requires_grad_(False) - flush() - - self.sd = StableDiffusion( - vae, - tokenizer, - text_encoder, - unet, - noise_scheduler, - is_xl=self.model_config.is_xl, - pipeline=pipe - ) - - unet.to(self.device_torch, dtype=dtype) if self.train_config.xformers: vae.set_use_memory_efficient_attention_xformers(True) unet.enable_xformers_memory_efficient_attention() if self.train_config.gradient_checkpointing: unet.enable_gradient_checkpointing() + unet.to(self.device_torch, dtype=dtype) unet.requires_grad_(False) unet.eval() + vae = vae.to(torch.device('cpu'), dtype=dtype) + vae.requires_grad_(False) + vae.eval() if self.network_config is not None: self.network = LoRASpecialNetwork( @@ -598,6 +263,8 @@ class BaseSDTrainProcess(BaseTrainProcess): ) self.network.force_to(self.device_torch, dtype=dtype) + # give network to sd so it can use it + self.sd.network = self.network self.network.apply_to( text_encoder, @@ -650,7 +317,7 @@ class BaseSDTrainProcess(BaseTrainProcess): optimizer_params=self.train_config.optimizer_params) self.optimizer = optimizer - lr_scheduler = train_util.get_lr_scheduler( + lr_scheduler = get_lr_scheduler( self.train_config.lr_scheduler, optimizer, max_iterations=self.train_config.steps, diff --git a/jobs/process/GenerateProcess.py b/jobs/process/GenerateProcess.py new file mode 100644 index 00000000..ff017839 --- /dev/null +++ b/jobs/process/GenerateProcess.py @@ -0,0 +1,102 @@ +import gc +import os +from collections import OrderedDict +from typing import ForwardRef, List + +import torch +from safetensors.torch import save_file, load_file + +from jobs.process.BaseProcess import BaseProcess +from toolkit.config_modules import ModelConfig, GenerateImageConfig +from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \ + add_base_model_info_to_meta +from toolkit.stable_diffusion_model import StableDiffusion +from toolkit.train_tools import get_torch_dtype + + +class GenerateConfig: + prompts: List[str] + + def __init__(self, **kwargs): + self.sampler = kwargs.get('sampler', 'ddpm') + self.width = kwargs.get('width', 512) + self.height = kwargs.get('height', 512) + self.neg = kwargs.get('neg', '') + self.seed = kwargs.get('seed', -1) + self.guidance_scale = kwargs.get('guidance_scale', 7) + self.sample_steps = kwargs.get('sample_steps', 20) + self.prompt_2 = kwargs.get('prompt_2', None) + self.neg_2 = kwargs.get('neg_2', None) + self.prompts = kwargs.get('prompts', None) + self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) + self.ext = kwargs.get('ext', 'png') + self.prompt_file = kwargs.get('prompt_file', False) + if self.prompts is None: + raise ValueError("Prompts must be set") + if isinstance(self.prompts, str): + if os.path.exists(self.prompts): + with open(self.prompts, 'r') as f: + self.prompts = f.read().splitlines() + self.prompts = [p.strip() for p in self.prompts if len(p.strip()) > 0] + else: + raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts") + + +class GenerateProcess(BaseProcess): + process_id: int + config: OrderedDict + progress_bar: ForwardRef('tqdm') = None + sd: StableDiffusion + + def __init__( + self, + process_id: int, + job, + config: OrderedDict + ): + 
super().__init__(process_id, job, config) + self.output_folder = self.get_conf('output_folder', required=True) + self.model_config = ModelConfig(**self.get_conf('model', required=True)) + self.device = self.get_conf('device', self.job.device) + self.generate_config = GenerateConfig(**self.get_conf('generate', required=True)) + + self.progress_bar = None + self.sd = StableDiffusion( + device=self.device, + model_config=self.model_config, + dtype=self.model_config.dtype, + ) + print(f"Using device {self.device}") + + def run(self): + super().run() + print("Loading model...") + self.sd.load_model() + + print(f"Generating {len(self.generate_config.prompts)} images") + # build prompt image configs + prompt_image_configs = [] + for prompt in self.generate_config.prompts: + prompt_image_configs.append(GenerateImageConfig( + prompt=prompt, + prompt_2=self.generate_config.prompt_2, + width=self.generate_config.width, + height=self.generate_config.height, + num_inference_steps=self.generate_config.sample_steps, + guidance_scale=self.generate_config.guidance_scale, + negative_prompt=self.generate_config.neg, + negative_prompt_2=self.generate_config.neg_2, + seed=self.generate_config.seed, + guidance_rescale=self.generate_config.guidance_rescale, + output_ext=self.generate_config.ext, + output_folder=self.output_folder, + add_prompt_file=self.generate_config.prompt_file + )) + # generate images + self.sd.generate_images(prompt_image_configs) + + print("Done generating images") + # cleanup + del self.sd + gc.collect() + torch.cuda.empty_cache() diff --git a/jobs/process/TrainSDRescaleProcess.py b/jobs/process/TrainSDRescaleProcess.py index a0ed1855..e2c4d77c 100644 --- a/jobs/process/TrainSDRescaleProcess.py +++ b/jobs/process/TrainSDRescaleProcess.py @@ -202,9 +202,11 @@ class TrainSDRescaleProcess(BaseSDTrainProcess): ) # get noise - noise = self.get_latent_noise( + noise = self.sd.get_latent_noise( pixel_height=self.rescale_config.from_resolution, pixel_width=self.rescale_config.from_resolution, + batch_size=self.train_config.batch_size, + noise_offset=self.train_config.noise_offset, ).to(self.device_torch, dtype=dtype) torch.set_default_device(self.device_torch) @@ -238,7 +240,7 @@ class TrainSDRescaleProcess(BaseSDTrainProcess): ) with torch.no_grad(): - noise_pred_target = self.predict_noise( + noise_pred_target = self.sd.predict_noise( latents, text_embeddings=text_embeddings, timestep=timestep, @@ -256,7 +258,7 @@ class TrainSDRescaleProcess(BaseSDTrainProcess): with self.network: assert self.network.is_active self.network.multiplier = 1.0 - noise_pred_train = self.predict_noise( + noise_pred_train = self.sd.predict_noise( reduced_latents, text_embeddings=text_embeddings, timestep=timestep, diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index e618d9c9..cb012e60 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -1,7 +1,6 @@ # ref: # - https://github.com/p1atdev/LECO/blob/main/train_lora.py import random -import time from collections import OrderedDict import os from typing import Optional @@ -14,16 +13,12 @@ from toolkit.paths import REPOS_ROOT import sys from toolkit.stable_diffusion_model import PromptEmbeds - -sys.path.append(REPOS_ROOT) -sys.path.append(os.path.join(REPOS_ROOT, 'leco')) -from toolkit.train_tools import get_torch_dtype, apply_noise_offset +from toolkit.train_tools import get_torch_dtype import gc from toolkit import train_tools import torch -from leco import train_util, model_util -from 
.BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion +from .BaseSDTrainProcess import BaseSDTrainProcess class ACTION_TYPES_SLIDER: @@ -131,7 +126,6 @@ class TrainSliderProcess(BaseSDTrainProcess): self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..") - if not self.slider_config.prompt_tensors: # shuffle random.shuffle(self.prompt_txt_list) @@ -175,8 +169,8 @@ class TrainSliderProcess(BaseSDTrainProcess): for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False): for target in self.slider_config.targets: prompt_list = [ - f"{target.target_class}", # target_class - f"{target.target_class} {neutral}", # target_class with neutral + f"{target.target_class}", # target_class + f"{target.target_class} {neutral}", # target_class with neutral f"{target.positive}", # positive_target f"{target.positive} {neutral}", # positive_target with neutral f"{target.negative}", # negative_target @@ -320,7 +314,6 @@ class TrainSliderProcess(BaseSDTrainProcess): ) ] - # move to cpu to save vram # We don't need text encoder anymore, but keep it on cpu for sampling # if text encoder is list @@ -364,7 +357,7 @@ class TrainSliderProcess(BaseSDTrainProcess): loss_function = torch.nn.MSELoss() def get_noise_pred(neg, pos, gs, cts, dn): - return self.predict_noise( + return self.sd.predict_noise( latents=dn, text_embeddings=train_tools.concat_prompt_embeddings( neg, # negative prompt @@ -391,9 +384,11 @@ class TrainSliderProcess(BaseSDTrainProcess): ).item() # get noise - noise = self.get_latent_noise( + noise = self.sd.get_latent_noise( pixel_height=height, pixel_width=width, + batch_size=self.train_config.batch_size, + noise_offset=self.train_config.noise_offset, ).to(self.device_torch, dtype=dtype) # get latents @@ -403,7 +398,7 @@ class TrainSliderProcess(BaseSDTrainProcess): with self.network: assert self.network.is_active self.network.multiplier = multiplier * rand_weight - denoised_latents = self.diffuse_some_steps( + denoised_latents = self.sd.diffuse_some_steps( latents, # pass simple noise latents train_tools.concat_prompt_embeddings( prompt_pair.positive_target, # unconditional diff --git a/jobs/process/TrainSliderProcessOld.py b/jobs/process/TrainSliderProcessOld.py index a33f6314..9a673c46 100644 --- a/jobs/process/TrainSliderProcessOld.py +++ b/jobs/process/TrainSliderProcessOld.py @@ -245,7 +245,7 @@ class TrainSliderProcessOld(BaseSDTrainProcess): loss_function = torch.nn.MSELoss() def get_noise_pred(p, n, gs, cts, dn): - return self.predict_noise( + return self.sd.predict_noise( latents=dn, text_embeddings=train_tools.concat_prompt_embeddings( p, # unconditional @@ -272,9 +272,11 @@ class TrainSliderProcessOld(BaseSDTrainProcess): ).item() # get noise - noise = self.get_latent_noise( + noise = self.sd.get_latent_noise( pixel_height=height, pixel_width=width, + batch_size=self.train_config.batch_size, + noise_offset=self.train_config.noise_offset, ).to(self.device_torch, dtype=dtype) # get latents @@ -284,7 +286,7 @@ class TrainSliderProcessOld(BaseSDTrainProcess): with self.network: assert self.network.is_active self.network.multiplier = multiplier - denoised_latents = self.diffuse_some_steps( + denoised_latents = self.sd.diffuse_some_steps( latents, # pass simple noise latents train_tools.concat_prompt_embeddings( positive, # unconditional diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py index e58e0069..a227a7a7 100644 --- a/jobs/process/__init__.py +++ b/jobs/process/__init__.py @@ -10,3 +10,4 @@ from .TrainSliderProcessOld import 
TrainSliderProcessOld from .TrainLoRAHack import TrainLoRAHack from .TrainSDRescaleProcess import TrainSDRescaleProcess from .ModRescaleLoraProcess import ModRescaleLoraProcess +from .GenerateProcess import GenerateProcess diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 3848461f..16861140 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -1,4 +1,7 @@ -from typing import List +import os +import time +from typing import List, Optional +import random class SaveConfig: @@ -27,6 +30,7 @@ class SampleConfig: self.guidance_scale = kwargs.get('guidance_scale', 7) self.sample_steps = kwargs.get('sample_steps', 20) self.network_multiplier = kwargs.get('network_multiplier', 1) + self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) class NetworkConfig: @@ -35,7 +39,7 @@ class NetworkConfig: rank = kwargs.get('rank', None) linear = kwargs.get('linear', None) if rank is not None: - self.rank: int = rank # rank for backward compatibility + self.rank: int = rank # rank for backward compatibility self.linear: int = rank elif linear is not None: self.rank: int = linear @@ -71,6 +75,7 @@ class ModelConfig: self.is_v2: bool = kwargs.get('is_v2', False) self.is_xl: bool = kwargs.get('is_xl', False) self.is_v_pred: bool = kwargs.get('is_v_pred', False) + self.dtype: str = kwargs.get('dtype', 'float16') if self.name_or_path is None: raise ValueError('name_or_path must be specified') @@ -103,3 +108,197 @@ class SliderConfig: self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]]) self.prompt_file: str = kwargs.get('prompt_file', None) self.prompt_tensors: str = kwargs.get('prompt_tensors', None) + + +class GenerateImageConfig: + def __init__( + self, + prompt: str = '', + prompt_2: Optional[str] = None, + width: int = 512, + height: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str = '', + negative_prompt_2: Optional[str] = None, + seed: int = -1, + network_multiplier: float = 1.0, + guidance_rescale: float = 0.0, + # the tag [time] will be replaced with milliseconds since epoch + output_path: str = None, # full image path + output_folder: str = None, # folder to save image in if output_path is not specified + output_ext: str = 'png', # extension to save image as if output_path is not specified + output_tail: str = '', # tail to add to output filename + add_prompt_file: bool = False, # add a prompt file with generated image + ): + self.width: int = width + self.height: int = height + self.num_inference_steps: int = num_inference_steps + self.guidance_scale: float = guidance_scale + self.guidance_rescale: float = guidance_rescale + self.prompt: str = prompt + self.prompt_2: str = prompt_2 + self.negative_prompt: str = negative_prompt + self.negative_prompt_2: str = negative_prompt_2 + + self.output_path: str = output_path + self.seed: int = seed + if self.seed == -1: + # generate random one + self.seed = random.randint(0, 2 ** 32 - 1) + self.network_multiplier: float = network_multiplier + self.output_folder: str = output_folder + self.output_ext: str = output_ext + self.add_prompt_file: bool = add_prompt_file + self.output_tail: str = output_tail + self.gen_time: int = int(time.time() * 1000) + + # prompt string will override any settings above + self._process_prompt_string() + + # handle dual text encoder prompts if nothing passed + if negative_prompt_2 is None: + self.negative_prompt_2 = negative_prompt + + if prompt_2 is None: + self.prompt_2 = prompt + + # parse prompt paths + if 
self.output_path is None and self.output_folder is None:
+            raise ValueError('output_path or output_folder must be specified')
+        elif self.output_path is not None:
+            self.output_folder = os.path.dirname(self.output_path)
+            self.output_ext = os.path.splitext(self.output_path)[1][1:]
+            self.output_filename_no_ext = os.path.splitext(os.path.basename(self.output_path))[0]
+
+        else:
+            self.output_filename_no_ext = '[time]_[count]'
+            if len(self.output_tail) > 0:
+                self.output_filename_no_ext += '_' + self.output_tail
+            self.output_path = os.path.join(self.output_folder, self.output_filename_no_ext + '.' + self.output_ext)
+
+        # adjust height
+        self.height = max(64, self.height - self.height % 8)  # round to divisible by 8
+        self.width = max(64, self.width - self.width % 8)  # round to divisible by 8
+
+    def set_gen_time(self, gen_time: int = None):
+        if gen_time is not None:
+            self.gen_time = gen_time
+        else:
+            self.gen_time = int(time.time() * 1000)
+
+    def _get_path_no_ext(self, count: int = 0, max_count=0):
+        # zero pad count
+        count_str = str(count).zfill(len(str(max_count)))
+        # replace [time] with gen time
+        filename = self.output_filename_no_ext.replace('[time]', str(self.gen_time))
+        # replace [count] with count
+        filename = filename.replace('[count]', count_str)
+        return filename
+
+    def get_image_path(self, count: int = 0, max_count=0):
+        filename = self._get_path_no_ext(count, max_count)
+        filename += '.' + self.output_ext
+        # join with folder
+        return os.path.join(self.output_folder, filename)
+
+    def get_prompt_path(self, count: int = 0, max_count=0):
+        filename = self._get_path_no_ext(count, max_count)
+        filename += '.txt'
+        # join with folder
+        return os.path.join(self.output_folder, filename)
+
+    def save_image(self, image, count: int = 0, max_count=0):
+        # make parent dirs
+        os.makedirs(self.output_folder, exist_ok=True)
+        self.set_gen_time()
+        # TODO save image gen header info for A1111 and us; our seeds probably won't match
+        image.save(self.get_image_path(count, max_count))
+        # do prompt file
+        if self.add_prompt_file:
+            self.save_prompt_file(count, max_count)
+
+    def save_prompt_file(self, count: int = 0, max_count=0):
+        # save prompt file
+        with open(self.get_prompt_path(count, max_count), 'w') as f:
+            prompt = self.prompt
+            if self.prompt_2 is not None:
+                prompt += ' --p2 ' + self.prompt_2
+            if self.negative_prompt is not None:
+                prompt += ' --n ' + self.negative_prompt
+            if self.negative_prompt_2 is not None:
+                prompt += ' --n2 ' + self.negative_prompt_2
+            prompt += ' --w ' + str(self.width)
+            prompt += ' --h ' + str(self.height)
+            prompt += ' --seed ' + str(self.seed)
+            prompt += ' --cfg ' + str(self.guidance_scale)
+            prompt += ' --steps ' + str(self.num_inference_steps)
+            prompt += ' --m ' + str(self.network_multiplier)
+            prompt += ' --gr ' + str(self.guidance_rescale)
+
+            # write the rebuilt prompt string (with all flags) to the file
+            f.write(prompt)
+
+    def _process_prompt_string(self):
+        # we will try to support all sd-scripts flags where we can
+
+        # FROM SD-SCRIPTS
+        # --n Treat everything until the next option as a negative prompt.
+        # --w Specify the width of the generated image.
+        # --h Specify the height of the generated image.
+        # --d Specify the seed for the generated image.
+        # --l Specify the CFG scale for the generated image.
+        # --s Specify the number of steps during generation.
+
+        # OURS and some QOL additions
+        # --m Specify the network multiplier for the generated image.
+ # --p2 Prompt for the second text encoder (SDXL only) + # --n2 Negative prompt for the second text encoder (SDXL only) + # --gr Specify the guidance rescale for the generated image (SDXL only) + + # --seed Specify the seed for the generated image same as --d + # --cfg Specify the CFG scale for the generated image same as --l + # --steps Specify the number of steps during generation same as --s + # --network_multiplier Specify the network multiplier for the generated image same as --m + + # process prompt string and update values if it has some + if self.prompt is not None and len(self.prompt) > 0: + # process prompt string + prompt = self.prompt + prompt = prompt.strip() + p_split = prompt.split('--') + self.prompt = p_split[0].strip() + + if len(p_split) > 1: + for split in p_split[1:]: + # allows multi char flags + flag = split.split(' ')[0].strip() + content = split[len(flag):].strip() + if flag == 'p2': + self.prompt_2 = content + elif flag == 'n': + self.negative_prompt = content + elif flag == 'n2': + self.negative_prompt_2 = content + elif flag == 'w': + self.width = int(content) + elif flag == 'h': + self.height = int(content) + elif flag == 'd': + self.seed = int(content) + elif flag == 'seed': + self.seed = int(content) + elif flag == 'l': + self.guidance_scale = float(content) + elif flag == 'cfg': + self.guidance_scale = float(content) + elif flag == 's': + self.num_inference_steps = int(content) + elif flag == 'steps': + self.num_inference_steps = int(content) + elif flag == 'm': + self.network_multiplier = float(content) + elif flag == 'network_multiplier': + self.network_multiplier = float(content) + elif flag == 'gr': + self.guidance_rescale = float(content) diff --git a/toolkit/job.py b/toolkit/job.py index c0b1a191..60752740 100644 --- a/toolkit/job.py +++ b/toolkit/job.py @@ -16,6 +16,9 @@ def get_job(config_path, name=None): if job == 'mod': from jobs import ModJob return ModJob(config) + if job == 'generate': + from jobs import GenerateJob + return GenerateJob(config) # elif job == 'train': # from jobs import TrainJob diff --git a/toolkit/scheduler.py b/toolkit/scheduler.py new file mode 100644 index 00000000..ab5558a5 --- /dev/null +++ b/toolkit/scheduler.py @@ -0,0 +1,33 @@ +import torch +from typing import Optional + + +def get_lr_scheduler( + name: Optional[str], + optimizer: torch.optim.Optimizer, + max_iterations: Optional[int], + lr_min: Optional[float], + **kwargs, +): + if name == "cosine": + return torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs + ) + elif name == "cosine_with_restarts": + return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, T_0=max_iterations, T_mult=2, eta_min=lr_min, **kwargs + ) + elif name == "step": + return torch.optim.lr_scheduler.StepLR( + optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs + ) + elif name == "constant": + return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs) + elif name == "linear": + return torch.optim.lr_scheduler.LinearLR( + optimizer, start_factor=0.5, end_factor=0.5, total_iters=max_iterations, **kwargs + ) + else: + raise ValueError( + "Scheduler must be cosine, cosine_with_restarts, step, linear or constant" + ) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 8b465249..5fb253d6 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1,12 +1,16 @@ +import gc import typing -from typing import Union, OrderedDict +from typing 
import Union, OrderedDict, List import sys import os +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg from safetensors.torch import save_file +from tqdm import tqdm +from toolkit.config_modules import ModelConfig, GenerateImageConfig from toolkit.paths import REPOS_ROOT -from toolkit.train_tools import get_torch_dtype +from toolkit.train_tools import get_torch_dtype, apply_noise_offset sys.path.append(REPOS_ROOT) sys.path.append(os.path.join(REPOS_ROOT, 'leco')) @@ -14,6 +18,32 @@ from leco import train_util import torch from library import model_util from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl +from diffusers.schedulers import DDPMScheduler +from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline + + +class BlankNetwork: + multiplier = 1.0 + is_active = True + + def __init__(self): + pass + + def __enter__(self): + self.is_active = True + + def __exit__(self, exc_type, exc_val, exc_tb): + self.is_active = False + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。 +VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 class PromptEmbeds: @@ -39,31 +69,382 @@ class PromptEmbeds: # if is type checking if typing.TYPE_CHECKING: - from diffusers import StableDiffusionPipeline - from toolkit.pipelines import CustomStableDiffusionXLPipeline + from diffusers import \ + StableDiffusionPipeline, \ + AutoencoderKL, \ + UNet2DConditionModel + from diffusers.schedulers import KarrasDiffusionSchedulers + from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection class StableDiffusion: pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline'] + vae: Union[None, 'AutoencoderKL'] + unet: Union[None, 'UNet2DConditionModel'] + text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] + tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] + noise_scheduler: Union[None, 'KarrasDiffusionSchedulers', 'DDPMScheduler'] + device: str + dtype: str + torch_dtype: torch.dtype + device_torch: torch.device + model_config: ModelConfig def __init__( self, - vae, - tokenizer, - text_encoder, - unet, - noise_scheduler, - is_xl=False, - pipeline=None, + device, + model_config: ModelConfig, + dtype='fp16', + custom_pipeline=None ): - # text encoder has a list of 2 for xl - self.vae = vae + self.custom_pipeline = custom_pipeline + self.device = device + self.dtype = dtype + self.torch_dtype = get_torch_dtype(dtype) + self.device_torch = torch.device(self.device) + self.model_config = model_config + self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon" + + # sdxl stuff + self.logit_scale = None + self.ckppt_info = None + + # to hold network if there is one + self.network = None + self.is_xl = model_config.is_xl + self.is_v2 = model_config.is_v2 + + def load_model(self): + dtype = get_torch_dtype(self.dtype) + + # TODO handle other schedulers + # sch = KDPM2DiscreteScheduler + sch = DDPMScheduler + # do our own scheduler + prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon" + scheduler = sch( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.0120, + beta_schedule="scaled_linear", + clip_sample=False, + prediction_type=prediction_type, + steps_offset=1 + ) + if 
self.model_config.is_xl: + if self.custom_pipeline is not None: + pipln = self.custom_pipeline + else: + pipln = CustomStableDiffusionXLPipeline + pipe = pipln.from_single_file( + self.model_config.name_or_path, + dtype=dtype, + scheduler_type='ddpm', + device=self.device_torch, + ).to(self.device_torch) + + text_encoders = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + for text_encoder in text_encoders: + text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + text_encoder = text_encoders + else: + if self.custom_pipeline is not None: + pipln = self.custom_pipeline + else: + pipln = CustomStableDiffusionPipeline + pipe = pipln.from_single_file( + self.model_config.name_or_path, + dtype=dtype, + scheduler_type='dpm', + device=self.device_torch, + load_safety_checker=False, + ).to(self.device_torch) + pipe.register_to_config(requires_safety_checker=False) + text_encoder = pipe.text_encoder + text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + tokenizer = pipe.tokenizer + + # scheduler doesn't get set sometimes, so we set it here + pipe.scheduler = scheduler + + self.unet = pipe.unet + self.noise_scheduler = pipe.scheduler + self.vae = pipe.vae.to(self.device_torch, dtype=dtype) + self.vae.eval() + self.vae.requires_grad_(False) + self.unet.to(self.device_torch, dtype=dtype) + self.unet.requires_grad_(False) + self.unet.eval() + self.tokenizer = tokenizer self.text_encoder = text_encoder - self.unet = unet - self.noise_scheduler = noise_scheduler - self.is_xl = is_xl - self.pipeline = pipeline + self.pipeline = pipe + + def generate_images(self, image_configs: List[GenerateImageConfig]): + # sample_folder = os.path.join(self.save_root, 'samples') + if self.network is not None: + self.network.eval() + network = self.network + else: + network = BlankNetwork() + + # save current seed state for training + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + + original_device_dict = { + 'vae': self.vae.device, + 'unet': self.unet.device, + # 'tokenizer': self.tokenizer.device, + } + + # handle sdxl text encoder + if isinstance(self.text_encoder, list): + for encoder, i in zip(self.text_encoder, range(len(self.text_encoder))): + original_device_dict[f'text_encoder_{i}'] = encoder.device + encoder.to(self.device_torch) + else: + original_device_dict['text_encoder'] = self.text_encoder.device + self.text_encoder.to(self.device_torch) + + self.vae.to(self.device_torch) + self.unet.to(self.device_torch) + + # TODO add clip skip + if self.is_xl: + pipeline = StableDiffusionXLPipeline( + vae=self.vae, + unet=self.unet, + text_encoder=self.text_encoder[0], + text_encoder_2=self.text_encoder[1], + tokenizer=self.tokenizer[0], + tokenizer_2=self.tokenizer[1], + scheduler=self.noise_scheduler, + add_watermarker=False, + ).to(self.device_torch) + # force turn that (ruin your images with obvious green and red dots) the #$@@ off!!! 
+ pipeline.watermark = None + else: + pipeline = StableDiffusionPipeline( + vae=self.vae, + unet=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=self.noise_scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ).to(self.device_torch) + # disable progress bar + pipeline.set_progress_bar_config(disable=True) + + start_multiplier = 1.0 + if self.network is not None: + start_multiplier = self.network.multiplier + + pipeline.to(self.device_torch) + with network: + with torch.no_grad(): + if self.network is not None: + assert self.network.is_active + + for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): + gen_config = image_configs[i] + + if self.network is not None: + self.network.multiplier = gen_config.network_multiplier + torch.manual_seed(gen_config.seed) + torch.cuda.manual_seed(gen_config.seed) + + if self.is_xl: + img = pipeline( + prompt=gen_config.prompt, + prompt_2=gen_config.prompt_2, + negative_prompt=gen_config.negative_prompt, + negative_prompt_2=gen_config.negative_prompt_2, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + guidance_rescale=gen_config.guidance_rescale, + ).images[0] + else: + img = pipeline( + prompt=gen_config.prompt, + negative_prompt=gen_config.negative_prompt, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + ).images[0] + + gen_config.save_image(img) + + # clear pipeline and cache to reduce vram usage + del pipeline + torch.cuda.empty_cache() + + # restore training state + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + self.vae.to(original_device_dict['vae']) + self.unet.to(original_device_dict['unet']) + if isinstance(self.text_encoder, list): + for encoder, i in zip(self.text_encoder, range(len(self.text_encoder))): + encoder.to(original_device_dict[f'text_encoder_{i}']) + else: + self.text_encoder.to(original_device_dict['text_encoder']) + if self.network is not None: + self.network.train() + self.network.multiplier = start_multiplier + # self.tokenizer.to(original_device_dict['tokenizer']) + + def get_latent_noise( + self, + height=None, + width=None, + pixel_height=None, + pixel_width=None, + batch_size=1, + noise_offset=0.0, + ): + if height is None and pixel_height is None: + raise ValueError("height or pixel_height must be specified") + if width is None and pixel_width is None: + raise ValueError("width or pixel_width must be specified") + if height is None: + height = pixel_height // VAE_SCALE_FACTOR + if width is None: + width = pixel_width // VAE_SCALE_FACTOR + + noise = torch.randn( + ( + batch_size, + UNET_IN_CHANNELS, + height, + width, + ), + device="cpu", + ) + noise = apply_noise_offset(noise, noise_offset) + return noise + + def get_time_ids_from_latents(self, latents: torch.Tensor): + bs, ch, h, w = list(latents.shape) + + height = h * VAE_SCALE_FACTOR + width = w * VAE_SCALE_FACTOR + + dtype = latents.dtype + + if self.is_xl: + prompt_ids = train_util.get_add_time_ids( + height, + width, + dynamic_crops=False, # look into this + dtype=dtype, + ).to(self.device_torch, dtype=dtype) + return train_util.concat_embeddings( + prompt_ids, prompt_ids, bs + ) + else: + return None + + def predict_noise( + self, + latents: torch.FloatTensor, + text_embeddings: PromptEmbeds, + 
timestep: int, + guidance_scale=7.5, + guidance_rescale=0, # 0.7 + add_time_ids=None, + **kwargs, + ): + + if self.is_xl: + if add_time_ids is None: + add_time_ids = self.get_time_ids_from_latents(latents) + + latent_model_input = torch.cat([latents] * 2) + + latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep) + + added_cond_kwargs = { + "text_embeds": text_embeddings.pooled_embeds, + "time_ids": add_time_ids, + } + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings.text_embeds, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + else: + # if we are doing classifier free guidance, need to double up + latent_model_input = torch.cat([latents] * 2) + + latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings.text_embeds, + ).sample + + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + return noise_pred + + # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 + def diffuse_some_steps( + self, + latents: torch.FloatTensor, + text_embeddings: PromptEmbeds, + total_timesteps: int = 1000, + start_timesteps=0, + guidance_scale=1, + add_time_ids=None, + **kwargs, + ): + + for timestep in tqdm(self.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False): + noise_pred = self.predict_noise( + latents, + text_embeddings, + timestep, + guidance_scale=guidance_scale, + add_time_ids=add_time_ids, + **kwargs, + ) + latents = self.noise_scheduler.step(noise_pred, timestep, latents).prev_sample + + # return latents_steps + return latents def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds: prompt = prompt