From 27f343fc0833ac7db5c9f826ea9bb5fc06a3f713 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 16 Sep 2023 08:30:38 -0600 Subject: [PATCH] Added base setup for training t2i adapters. Currently untested, saw something else shiny i wanted to finish sirst. Added content_or_style to the training config. It defaults to balanced, which is standard uniform time step sampling. If style or content is passed, it will use cubic sampling for timesteps to favor timesteps that are beneficial for training them. for style, favor later timesteps. For content, favor earlier timesteps. --- extensions_built_in/sd_trainer/SDTrainer.py | 97 ++++++++++-- jobs/process/BaseSDTrainProcess.py | 160 ++++++++++++++------ requirements.txt | 2 +- toolkit/config_modules.py | 20 ++- toolkit/network_mixins.py | 8 +- toolkit/saving.py | 29 +++- toolkit/sd_device_states_presets.py | 21 ++- toolkit/stable_diffusion_model.py | 61 +++++++- 8 files changed, 314 insertions(+), 84 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 522e5279..4be9bd5c 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1,17 +1,27 @@ +import os.path from collections import OrderedDict + +from PIL import Image from torch.utils.data import DataLoader + +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork from toolkit.train_tools import get_torch_dtype, apply_snr_weight import gc import torch from jobs.process import BaseSDTrainProcess +from torchvision import transforms def flush(): torch.cuda.empty_cache() gc.collect() +adapter_transforms = transforms.Compose([ + transforms.PILToTensor(), +]) + class SDTrainer(BaseSDTrainProcess): @@ -31,11 +41,47 @@ class SDTrainer(BaseSDTrainProcess): self.sd.vae.to('cpu') flush() + def get_adapter_images(self, batch: 'DataLoaderBatchDTO'): + img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] + adapter_folder_path = self.adapter_config.image_dir + adapter_images = [] + # loop through images + for file_item in batch.file_items: + img_path = file_item.path + file_name_no_ext = os.path.basename(img_path).split('.')[0] + # find the image + for ext in img_ext_list: + if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)): + adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext)) + break + + adapter_tensors = [] + # load images with torch transforms + for adapter_image in adapter_images: + img = Image.open(adapter_image) + img = adapter_transforms(img) + adapter_tensors.append(img) + + # stack them + adapter_tensors = torch.stack(adapter_tensors) + return adapter_tensors + def hook_train_loop(self, batch): dtype = get_torch_dtype(self.train_config.dtype) noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) network_weight_list = batch.get_network_weight_list() + + adapter_images = None + sigmas = None + if self.adapter: + # todo move this to data loader + adapter_images = self.get_adapter_images(batch) + # not 100% sure what this does. But they do it here + # https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170 + sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype) + noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5) + # flush() self.optimizer.zero_grad() @@ -64,30 +110,55 @@ class SDTrainer(BaseSDTrainProcess): # detach the embeddings conditional_embeds = conditional_embeds.detach() # flush() + pred_kwargs = {} + if self.adapter: + down_block_additional_residuals = self.adapter(adapter_images) + down_block_additional_residuals = [ + sample.to(dtype=dtype) for sample in down_block_additional_residuals + ] + pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals noise_pred = self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype), conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), timestep=timesteps, guidance_scale=1.0, + **pred_kwargs ) - # flush() - # 9.18 gb - noise = noise.to(self.device_torch, dtype=dtype).detach() + if self.adapter: + # todo, diffusers does this on t2i training, is it better approach? + # Denoise the latents + denoised_latents = noise_pred * (-sigmas) + noisy_latents + weighing = sigmas ** -2.0 - if self.sd.prediction_type == 'v_prediction': - # v-parameterization training - target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + # Get the target for loss depending on the prediction type + if self.sd.noise_scheduler.config.prediction_type == "epsilon": + target = batch.latents # we are computing loss against denoise latents + elif self.sd.noise_scheduler.config.prediction_type == "v_prediction": + target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}") + + # MSE loss + loss = torch.mean( + (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1), + dim=1, + ) else: - target = noise + noise = noise.to(self.device_torch, dtype=dtype).detach() + if self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + else: + target = noise + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: - # add min_snr_gamma - loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + # TODO: I think the sigma method does not need this. Check + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) loss = loss.mean() diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 5d35ffb3..5f0049af 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -5,6 +5,7 @@ from collections import OrderedDict import os from typing import Union +from diffusers import T2IAdapter # from lycoris.config import PRESET from torch.utils.data import DataLoader import torch @@ -21,6 +22,7 @@ from toolkit.optimizer import get_optimizer from toolkit.paths import CONFIG_ROOT from toolkit.progress_bar import ToolkitProgressBar from toolkit.sampler import get_sampler +from toolkit.saving import save_t2i_from_diffusers, load_t2i_model from toolkit.scheduler import get_lr_scheduler from toolkit.sd_device_states_presets import get_train_sd_device_state_preset @@ -34,7 +36,7 @@ import gc from tqdm import tqdm from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \ - GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config + GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig def flush(): @@ -105,9 +107,18 @@ class BaseSDTrainProcess(BaseTrainProcess): if embedding_raw is not None: self.embed_config = EmbeddingConfig(**embedding_raw) + # t2i adapter + self.adapter_config = None + adapter_raw = self.get_conf('adapter', None) + if adapter_raw is not None: + self.adapter_config = AdapterConfig(**adapter_raw) + # sdxl adapters end in _xl. Only full_adapter_xl for now + if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'): + self.adapter_config.adapter_type += '_xl' + model_config_to_load = copy.deepcopy(self.model_config) - if self.embed_config is None and self.network_config is None: + if self.embed_config is None and self.network_config is None and self.adapter_config is None: # get the latest checkpoint # check to see if we have a latest save latest_save_path = self.get_latest_save_path() @@ -135,6 +146,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # to hold network if there is one self.network: Union[Network, None] = None + self.adapter: Union[T2IAdapter, None] = None self.embedding: Union[Embedding, None] = None # get the device state preset based on what we are training @@ -144,6 +156,7 @@ class BaseSDTrainProcess(BaseTrainProcess): train_text_encoder=self.train_config.train_text_encoder, cached_latents=self.is_latents_cached, train_lora=self.network_config is not None, + train_adapter=self.adapter_config is not None, train_embedding=self.embed_config is not None, ) @@ -305,6 +318,15 @@ class BaseSDTrainProcess(BaseTrainProcess): # replace extension emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt" self.embedding.save(emb_file_path) + elif self.adapter is not None: + # save adapter + state_dict = self.adapter.state_dict() + save_t2i_from_diffusers( + state_dict, + output_file=file_path, + meta=save_meta, + dtype=get_torch_dtype(self.save_config.dtype) + ) else: self.sd.save( file_path, @@ -360,20 +382,35 @@ class BaseSDTrainProcess(BaseTrainProcess): else: return None + def load_training_state_from_metadata(self, path): + meta = load_metadata_from_safetensors(path) + # if 'training_info' in Orderdict keys + if 'training_info' in meta and 'step' in meta['training_info']: + self.step_num = meta['training_info']['step'] + self.start_step = self.step_num + print(f"Found step {self.step_num} in metadata, starting from there") + def load_weights(self, path): if self.network is not None: extra_weights = self.network.load_weights(path) - meta = load_metadata_from_safetensors(path) - # if 'training_info' in Orderdict keys - if 'training_info' in meta and 'step' in meta['training_info']: - self.step_num = meta['training_info']['step'] - self.start_step = self.step_num - print(f"Found step {self.step_num} in metadata, starting from there") + self.load_training_state_from_metadata(path) return extra_weights else: print("load_weights not implemented for non-network models") return None + def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32): + sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype) + schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch, ) + timesteps = timesteps.to(self.device_torch, ) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): with torch.no_grad(): prompts = batch.get_caption_list() @@ -407,8 +444,10 @@ class BaseSDTrainProcess(BaseTrainProcess): imgs = imgs.to(self.device_torch, dtype=dtype) if batch.latents is not None: latents = batch.latents.to(self.device_torch, dtype=dtype) + batch.latents = latents else: latents = self.sd.encode_images(imgs) + batch.latents = latents flush() batch_size = latents.shape[0] @@ -416,47 +455,46 @@ class BaseSDTrainProcess(BaseTrainProcess): self.sd.noise_scheduler.set_timesteps( 1000, device=self.device_torch ) - if self.train_config.use_progressive_denoising: - min_timestep = int(value_map( - self.step_num, - min_in=0, - max_in=self.train_config.max_denoising_steps, - min_out=self.train_config.min_denoising_steps, - max_out=self.train_config.max_denoising_steps - )) - elif self.train_config.use_linear_denoising: - # starts at max steps and walks down to min steps - min_timestep = int(value_map( - self.step_num, - min_in=0, - max_in=self.train_config.max_denoising_steps, - min_out=self.train_config.max_denoising_steps - 1, - max_out=self.train_config.min_denoising_steps - )) + # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content': + if self.train_config.content_or_style in ['style', 'content']: + # this is from diffusers training code + # Cubic sampling for favoring later or earlier timesteps + # For more details about why cubic sampling is used for content / structure, + # refer to section 3.4 of https://arxiv.org/abs/2302.08453 + + # for content / structure, it is best to favor earlier timesteps + # for style, it is best to favor later timesteps + + timesteps = torch.rand((batch_size,), device=latents.device) + + if self.train_config.content_or_style == 'style': + timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps'] + elif self.train_config.content_or_style == 'content': + timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps'] + + timesteps = value_map( + timesteps, + 0, + self.sd.noise_scheduler.config['num_train_timesteps'] - 1, + self.train_config.min_denoising_steps, + self.train_config.max_denoising_steps + ) + timesteps = timesteps.long().clamp( + self.train_config.min_denoising_steps, + self.train_config.max_denoising_steps - 1 + ) + + elif self.train_config.content_or_style == 'balanced': + timesteps = torch.randint( + self.train_config.min_denoising_steps, + self.train_config.max_denoising_steps, + (batch_size,), + device=self.device_torch + ) + timesteps = timesteps.long() else: - min_timestep = self.train_config.min_denoising_steps - - # todo improve this, but is skews odds for higher timesteps - # 50% chance to use midpoint as the min_time_step - mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2 - if torch.rand(1) > 0.5: - min_timestep = mid_point - - # 50% chance to use midpoint as the min_time_step - mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2 - if torch.rand(1) > 0.5: - min_timestep = mid_point - - min_timestep = int(min_timestep) - - timesteps = torch.randint( - min_timestep, - self.train_config.max_denoising_steps, - (batch_size,), - device=self.device_torch - ) - timesteps = timesteps.long() + raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}") # get noise noise = self.sd.get_latent_noise( @@ -477,6 +515,7 @@ class BaseSDTrainProcess(BaseTrainProcess): return noisy_latents, noise, timesteps, conditioned_prompts, imgs def run(self): + # torch.autograd.set_detect_anomaly(True) # run base process run BaseTrainProcess.run(self) @@ -653,9 +692,34 @@ class BaseSDTrainProcess(BaseTrainProcess): # set trainable params params = self.embedding.get_trainable_params() + flush() + elif self.adapter_config is not None: + self.adapter = T2IAdapter( + in_channels=self.adapter_config.in_channels, + channels=self.adapter_config.channels, + num_res_blocks=self.adapter_config.num_res_blocks, + downscale_factor=self.adapter_config.downscale_factor, + adapter_type=self.adapter_config.adapter_type, + ) + # t2i adapter + latest_save_path = self.get_latest_save_path(self.embed_config.trigger) + if latest_save_path is not None: + # load adapter from path + print(f"Loading adapter from {latest_save_path}") + loaded_state_dict = load_t2i_model( + latest_save_path, + self.device_torch, + dtype=dtype + ) + self.adapter.load_state_dict(loaded_state_dict) + self.load_training_state_from_metadata(latest_save_path) + params = self.get_params() + if not params: + # set trainable params + params = self.adapter.parameters() + self.sd.adapter = self.adapter flush() else: - # set the device state preset before getting params self.sd.set_device_state(self.train_device_state_preset) diff --git a/requirements.txt b/requirements.txt index 45a304ca..7bd2400f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torch torchvision safetensors -git+https://github.com/huggingface/diffusers.git +diffusers==0.21.1 transformers lycoris_lora flatten_json diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index f23a0eed..4979ae57 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -39,6 +39,7 @@ class SampleConfig: NetworkType = Literal['lora', 'locon'] + class NetworkConfig: def __init__(self, **kwargs): self.type: NetworkType = kwargs.get('type', 'lora') @@ -58,6 +59,17 @@ class NetworkConfig: self.dropout: Union[float, None] = kwargs.get('dropout', None) +class AdapterConfig: + def __init__(self, **kwargs): + self.in_channels: int = kwargs.get('in_channels', 3) + self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280]) + self.num_res_blocks: int = kwargs.get('num_res_blocks', 2) + self.downscale_factor: int = kwargs.get('downscale_factor', 16) + self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter') + self.image_dir: str = kwargs.get('image_dir', None) + self.test_img_path: str = kwargs.get('test_img_path', None) + + class EmbeddingConfig: def __init__(self, **kwargs): self.trigger = kwargs.get('trigger', 'custom_embedding') @@ -66,9 +78,13 @@ class EmbeddingConfig: self.save_format = kwargs.get('save_format', 'safetensors') +ContentOrStyleType = Literal['balanced', 'style', 'content'] + + class TrainConfig: def __init__(self, **kwargs): self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm') + self.content_or_style: ContentOrStyleType = kwargs.get('content_or_style', 'balanced') self.steps: int = kwargs.get('steps', 1000) self.lr = kwargs.get('lr', 1e-6) self.unet_lr = kwargs.get('unet_lr', self.lr) @@ -80,8 +96,6 @@ class TrainConfig: self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {}) self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0) self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000) - self.use_linear_denoising: int = kwargs.get('use_linear_denoising', False) - self.use_progressive_denoising: int = kwargs.get('use_progressive_denoising', False) self.batch_size: int = kwargs.get('batch_size', 1) self.dtype: str = kwargs.get('dtype', 'fp32') self.xformers = kwargs.get('xformers', False) @@ -255,6 +269,7 @@ class GenerateImageConfig: output_ext: str = ImgExt, # extension to save image as if output_path is not specified output_tail: str = '', # tail to add to output filename add_prompt_file: bool = False, # add a prompt file with generated image + adapter_image_path: str = None, # path to adapter image ): self.width: int = width self.height: int = height @@ -277,6 +292,7 @@ class GenerateImageConfig: self.add_prompt_file: bool = add_prompt_file self.output_tail: str = output_tail self.gen_time: int = int(time.time() * 1000) + self.adapter_image_path: str = adapter_image_path # prompt string will override any settings above self._process_prompt_string() diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index ebfd0b66..925bb3b4 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -341,13 +341,7 @@ class ToolkitNetworkMixin: if isinstance(multiplier, int) or isinstance(multiplier, float): tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype) elif isinstance(multiplier, list): - tensor_list = [] - for m in multiplier: - if isinstance(m, int) or isinstance(m, float): - tensor_list.append(torch.tensor((m,)).to(device, dtype=dtype)) - elif isinstance(m, torch.Tensor): - tensor_list.append(m.clone().detach().to(device, dtype=dtype)) - tensor_multiplier = torch.cat(tensor_list) + tensor_multiplier = torch.tensor(multiplier).to(device, dtype=dtype) elif isinstance(multiplier, torch.Tensor): tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype) diff --git a/toolkit/saving.py b/toolkit/saving.py index d978f20f..38c4a1ad 100644 --- a/toolkit/saving.py +++ b/toolkit/saving.py @@ -161,10 +161,35 @@ def save_lora_from_diffusers( else: converted_key = key + # make sure parent folder exists + os.makedirs(os.path.dirname(output_file), exist_ok=True) + save_file(converted_state_dict, output_file, metadata=meta) +def save_t2i_from_diffusers( + t2i_state_dict: 'OrderedDict', + output_file: str, + meta: 'OrderedDict', + dtype=get_torch_dtype('fp16'), +): + # todo: test compatibility with non diffusers + converted_state_dict = OrderedDict() + for key, value in t2i_state_dict.items(): + converted_state_dict[key] = value.detach().to('cpu', dtype=dtype) # make sure parent folder exists os.makedirs(os.path.dirname(output_file), exist_ok=True) - save_file(converted_state_dict, output_file, metadata=meta -) + save_file(converted_state_dict, output_file, metadata=meta) + + +def load_t2i_model( + path_to_file, + device: Union[str, torch.device] = 'cpu', + dtype: torch.dtype = torch.float32 +): + raw_state_dict = load_file(path_to_file, device) + converted_state_dict = OrderedDict() + for key, value in raw_state_dict.items(): + # todo see if we need to convert dict + converted_state_dict[key] = value.detach().to(device, dtype=dtype) + return converted_state_dict diff --git a/toolkit/sd_device_states_presets.py b/toolkit/sd_device_states_presets.py index 9ffbd945..fc16fd9e 100644 --- a/toolkit/sd_device_states_presets.py +++ b/toolkit/sd_device_states_presets.py @@ -1,3 +1,5 @@ +from typing import Union + import torch import copy @@ -15,16 +17,22 @@ empty_preset = { 'training': False, 'requires_grad': False, 'device': 'cpu', - } + }, + 'adapter': { + 'training': False, + 'requires_grad': False, + 'device': 'cpu', + }, } -def get_train_sd_device_state_preset ( - device: torch.DeviceObjType, +def get_train_sd_device_state_preset( + device: Union[str, torch.device], train_unet: bool = False, train_text_encoder: bool = False, cached_latents: bool = False, train_lora: bool = False, + train_adapter: bool = False, train_embedding: bool = False, ): preset = copy.deepcopy(empty_preset) @@ -51,9 +59,14 @@ def get_train_sd_device_state_preset ( preset['text_encoder']['training'] = True preset['unet']['training'] = True - if train_lora: preset['text_encoder']['requires_grad'] = False preset['unet']['requires_grad'] = False + if train_adapter: + preset['adapter']['requires_grad'] = True + preset['adapter']['training'] = True + preset['adapter']['device'] = device + preset['unet']['training'] = True + return preset diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 77704e48..871f76d3 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1,3 +1,4 @@ +import copy import gc import json import shutil @@ -7,6 +8,7 @@ import sys import os from collections import OrderedDict +from PIL import Image from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg from safetensors.torch import save_file, load_file from torch.nn import Parameter @@ -22,11 +24,13 @@ from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds from toolkit.sampler import get_sampler from toolkit.saving import save_ldm_model_from_diffusers +from toolkit.sd_device_states_presets import empty_preset from toolkit.train_tools import get_torch_dtype, apply_noise_offset import torch from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \ StableDiffusionKDiffusionXLPipeline -from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ + StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline import diffusers # tell it to shut up @@ -110,7 +114,7 @@ class StableDiffusion: self.unet: Union[None, 'UNet2DConditionModel'] self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] - self.noise_scheduler: Union[None, 'KarrasDiffusionSchedulers'] = noise_scheduler + self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler # sdxl stuff self.logit_scale = None @@ -119,6 +123,7 @@ class StableDiffusion: # to hold network if there is one self.network = None + self.adapter: Union['T2IAdapter', None] = None self.is_xl = model_config.is_xl self.is_v2 = model_config.is_v2 @@ -291,8 +296,18 @@ class StableDiffusion: if sampler.startswith("sample_") and self.is_xl: # using kdiffusion Pipe = StableDiffusionKDiffusionXLPipeline - else: + elif self.is_xl: Pipe = StableDiffusionXLPipeline + else: + Pipe = StableDiffusionPipeline + + extra_args = {} + if self.adapter: + if self.is_xl: + Pipe = StableDiffusionXLAdapterPipeline + else: + Pipe = StableDiffusionAdapterPipeline + extra_args['adapter'] = self.adapter # TODO add clip skip if self.is_xl: @@ -305,11 +320,12 @@ class StableDiffusion: tokenizer_2=self.tokenizer[1], scheduler=noise_scheduler, add_watermarker=False, + **extra_args ).to(self.device_torch) # force turn that (ruin your images with obvious green and red dots) the #$@@ off!!! pipeline.watermark = None else: - pipeline = StableDiffusionPipeline( + pipeline = Pipe( vae=self.vae, unet=self.unet, text_encoder=self.text_encoder, @@ -318,6 +334,7 @@ class StableDiffusion: safety_checker=None, feature_extractor=None, requires_safety_checker=False, + **extra_args ).to(self.device_torch) flush() # disable progress bar @@ -340,6 +357,12 @@ class StableDiffusion: for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): gen_config = image_configs[i] + extra = {} + if gen_config.adapter_image_path is not None: + validation_image = Image.open(gen_config.adapter_image_path).convert("RGB") + validation_image = validation_image.resize((gen_config.width, gen_config.height)) + extra['image'] = validation_image + if self.network is not None: self.network.multiplier = gen_config.network_multiplier torch.manual_seed(gen_config.seed) @@ -355,7 +378,6 @@ class StableDiffusion: grs = 0.7 # grs = 0.0 - extra = {} if sampler.startswith("sample_"): extra['use_karras_sigmas'] = True @@ -379,6 +401,7 @@ class StableDiffusion: width=gen_config.width, num_inference_steps=gen_config.num_inference_steps, guidance_scale=gen_config.guidance_scale, + **extra ).images[0] gen_config.save_image(img) @@ -517,6 +540,7 @@ class StableDiffusion: timestep, encoder_hidden_states=text_embeddings.text_embeds, added_cond_kwargs=added_cond_kwargs, + **kwargs, ).sample if do_classifier_free_guidance: @@ -558,6 +582,7 @@ class StableDiffusion: latent_model_input, timestep, encoder_hidden_states=text_embeddings.text_embeds, + **kwargs, ).sample if do_classifier_free_guidance: @@ -855,6 +880,7 @@ class StableDiffusion: # saves the current device state for all modules # this is useful for when we want to alter the state and restore it self.device_state = { + **empty_preset, 'vae': { 'training': self.vae.training, 'device': self.vae.device, @@ -880,6 +906,12 @@ class StableDiffusion: 'device': self.text_encoder.device, 'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad } + if self.adapter is not None: + self.device_state['adapter'] = { + 'training': self.adapter.training, + 'device': self.adapter.device, + 'requires_grad': self.adapter.requires_grad, + } def restore_device_state(self): # restores the device state for all modules @@ -927,6 +959,14 @@ class StableDiffusion: self.text_encoder.eval() self.text_encoder.to(state['text_encoder']['device']) self.text_encoder.requires_grad_(state['text_encoder']['requires_grad']) + + if self.adapter is not None: + self.adapter.to(state['adapter']['device']) + self.adapter.requires_grad_(state['adapter']['requires_grad']) + if state['adapter']['training']: + self.adapter.train() + else: + self.adapter.eval() flush() def set_device_state_preset(self, device_state_preset: DeviceStatePreset): @@ -940,9 +980,9 @@ class StableDiffusion: if device_state_preset in ['cache_latents']: active_modules = ['vae'] if device_state_preset in ['generate']: - active_modules = ['vae', 'unet', 'text_encoder'] + active_modules = ['vae', 'unet', 'text_encoder', 'adapter'] - state = {} + state = copy.deepcopy(empty_preset) # vae state['vae'] = { 'training': 'vae' in training_modules, @@ -973,4 +1013,11 @@ class StableDiffusion: 'requires_grad': 'text_encoder' in training_modules, } + if self.adapter is not None: + state['adapter'] = { + 'training': 'adapter' in training_modules, + 'device': self.device_torch if 'adapter' in active_modules else 'cpu', + 'requires_grad': 'adapter' in training_modules, + } + self.set_device_state(state)