diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 522e5279..4be9bd5c 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1,17 +1,27 @@ +import os.path from collections import OrderedDict + +from PIL import Image from torch.utils.data import DataLoader + +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork from toolkit.train_tools import get_torch_dtype, apply_snr_weight import gc import torch from jobs.process import BaseSDTrainProcess +from torchvision import transforms def flush(): torch.cuda.empty_cache() gc.collect() +adapter_transforms = transforms.Compose([ + transforms.PILToTensor(), +]) + class SDTrainer(BaseSDTrainProcess): @@ -31,11 +41,47 @@ class SDTrainer(BaseSDTrainProcess): self.sd.vae.to('cpu') flush() + def get_adapter_images(self, batch: 'DataLoaderBatchDTO'): + img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] + adapter_folder_path = self.adapter_config.image_dir + adapter_images = [] + # loop through images + for file_item in batch.file_items: + img_path = file_item.path + file_name_no_ext = os.path.basename(img_path).split('.')[0] + # find the image + for ext in img_ext_list: + if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)): + adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext)) + break + + adapter_tensors = [] + # load images with torch transforms + for adapter_image in adapter_images: + img = Image.open(adapter_image) + img = adapter_transforms(img) + adapter_tensors.append(img) + + # stack them + adapter_tensors = torch.stack(adapter_tensors) + return adapter_tensors + def hook_train_loop(self, batch): dtype = get_torch_dtype(self.train_config.dtype) noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) network_weight_list = batch.get_network_weight_list() + + adapter_images = None + sigmas = None + if self.adapter: + # todo move this to data loader + adapter_images = self.get_adapter_images(batch) + # not 100% sure what this does. But they do it here + # https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170 + sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype) + noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5) + # flush() self.optimizer.zero_grad() @@ -64,30 +110,55 @@ class SDTrainer(BaseSDTrainProcess): # detach the embeddings conditional_embeds = conditional_embeds.detach() # flush() + pred_kwargs = {} + if self.adapter: + down_block_additional_residuals = self.adapter(adapter_images) + down_block_additional_residuals = [ + sample.to(dtype=dtype) for sample in down_block_additional_residuals + ] + pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals noise_pred = self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype), conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), timestep=timesteps, guidance_scale=1.0, + **pred_kwargs ) - # flush() - # 9.18 gb - noise = noise.to(self.device_torch, dtype=dtype).detach() + if self.adapter: + # todo, diffusers does this on t2i training, is it better approach? 
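+                # the idea (from the linked diffusers example): with x_t = x_0 + sigma * eps, an
+                # estimate of the clean latents is eps_pred * (-sigma) + x_t, and weighting that
+                # x0-space error by sigma**-2 recovers an ordinary eps-space MSE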
+ # Denoise the latents + denoised_latents = noise_pred * (-sigmas) + noisy_latents + weighing = sigmas ** -2.0 - if self.sd.prediction_type == 'v_prediction': - # v-parameterization training - target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + # Get the target for loss depending on the prediction type + if self.sd.noise_scheduler.config.prediction_type == "epsilon": + target = batch.latents # we are computing loss against denoise latents + elif self.sd.noise_scheduler.config.prediction_type == "v_prediction": + target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}") + + # MSE loss + loss = torch.mean( + (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1), + dim=1, + ) else: - target = noise + noise = noise.to(self.device_torch, dtype=dtype).detach() + if self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + else: + target = noise + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: - # add min_snr_gamma - loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + # TODO: I think the sigma method does not need this. Check + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) loss = loss.mean() diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 5d35ffb3..5f0049af 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -5,6 +5,7 @@ from collections import OrderedDict import os from typing import Union +from diffusers import T2IAdapter # from lycoris.config import PRESET from torch.utils.data import DataLoader import torch @@ -21,6 +22,7 @@ from toolkit.optimizer import get_optimizer from toolkit.paths import CONFIG_ROOT from toolkit.progress_bar import ToolkitProgressBar from toolkit.sampler import get_sampler +from toolkit.saving import save_t2i_from_diffusers, load_t2i_model from toolkit.scheduler import get_lr_scheduler from toolkit.sd_device_states_presets import get_train_sd_device_state_preset @@ -34,7 +36,7 @@ import gc from tqdm import tqdm from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \ - GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config + GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig def flush(): @@ -105,9 +107,18 @@ class BaseSDTrainProcess(BaseTrainProcess): if embedding_raw is not None: self.embed_config = EmbeddingConfig(**embedding_raw) + # t2i adapter + self.adapter_config = None + adapter_raw = self.get_conf('adapter', None) + if adapter_raw is not None: + self.adapter_config = AdapterConfig(**adapter_raw) + # sdxl adapters end in _xl. 
Only full_adapter_xl for now + if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'): + self.adapter_config.adapter_type += '_xl' + model_config_to_load = copy.deepcopy(self.model_config) - if self.embed_config is None and self.network_config is None: + if self.embed_config is None and self.network_config is None and self.adapter_config is None: # get the latest checkpoint # check to see if we have a latest save latest_save_path = self.get_latest_save_path() @@ -135,6 +146,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # to hold network if there is one self.network: Union[Network, None] = None + self.adapter: Union[T2IAdapter, None] = None self.embedding: Union[Embedding, None] = None # get the device state preset based on what we are training @@ -144,6 +156,7 @@ class BaseSDTrainProcess(BaseTrainProcess): train_text_encoder=self.train_config.train_text_encoder, cached_latents=self.is_latents_cached, train_lora=self.network_config is not None, + train_adapter=self.adapter_config is not None, train_embedding=self.embed_config is not None, ) @@ -305,6 +318,15 @@ class BaseSDTrainProcess(BaseTrainProcess): # replace extension emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt" self.embedding.save(emb_file_path) + elif self.adapter is not None: + # save adapter + state_dict = self.adapter.state_dict() + save_t2i_from_diffusers( + state_dict, + output_file=file_path, + meta=save_meta, + dtype=get_torch_dtype(self.save_config.dtype) + ) else: self.sd.save( file_path, @@ -360,20 +382,35 @@ class BaseSDTrainProcess(BaseTrainProcess): else: return None + def load_training_state_from_metadata(self, path): + meta = load_metadata_from_safetensors(path) + # if 'training_info' in Orderdict keys + if 'training_info' in meta and 'step' in meta['training_info']: + self.step_num = meta['training_info']['step'] + self.start_step = self.step_num + print(f"Found step {self.step_num} in metadata, starting from there") + def load_weights(self, path): if self.network is not None: extra_weights = self.network.load_weights(path) - meta = load_metadata_from_safetensors(path) - # if 'training_info' in Orderdict keys - if 'training_info' in meta and 'step' in meta['training_info']: - self.step_num = meta['training_info']['step'] - self.start_step = self.step_num - print(f"Found step {self.step_num} in metadata, starting from there") + self.load_training_state_from_metadata(path) return extra_weights else: print("load_weights not implemented for non-network models") return None + def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32): + sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype) + schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch, ) + timesteps = timesteps.to(self.device_torch, ) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): with torch.no_grad(): prompts = batch.get_caption_list() @@ -407,8 +444,10 @@ class BaseSDTrainProcess(BaseTrainProcess): imgs = imgs.to(self.device_torch, dtype=dtype) if batch.latents is not None: latents = batch.latents.to(self.device_torch, dtype=dtype) + batch.latents = latents else: latents = self.sd.encode_images(imgs) + batch.latents = latents flush() batch_size = latents.shape[0] @@ -416,47 +455,46 @@ class BaseSDTrainProcess(BaseTrainProcess): 
self.sd.noise_scheduler.set_timesteps( 1000, device=self.device_torch ) - if self.train_config.use_progressive_denoising: - min_timestep = int(value_map( - self.step_num, - min_in=0, - max_in=self.train_config.max_denoising_steps, - min_out=self.train_config.min_denoising_steps, - max_out=self.train_config.max_denoising_steps - )) - elif self.train_config.use_linear_denoising: - # starts at max steps and walks down to min steps - min_timestep = int(value_map( - self.step_num, - min_in=0, - max_in=self.train_config.max_denoising_steps, - min_out=self.train_config.max_denoising_steps - 1, - max_out=self.train_config.min_denoising_steps - )) + # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content': + if self.train_config.content_or_style in ['style', 'content']: + # this is from diffusers training code + # Cubic sampling for favoring later or earlier timesteps + # For more details about why cubic sampling is used for content / structure, + # refer to section 3.4 of https://arxiv.org/abs/2302.08453 + + # for content / structure, it is best to favor earlier timesteps + # for style, it is best to favor later timesteps + + timesteps = torch.rand((batch_size,), device=latents.device) + + if self.train_config.content_or_style == 'style': + timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps'] + elif self.train_config.content_or_style == 'content': + timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps'] + + timesteps = value_map( + timesteps, + 0, + self.sd.noise_scheduler.config['num_train_timesteps'] - 1, + self.train_config.min_denoising_steps, + self.train_config.max_denoising_steps + ) + timesteps = timesteps.long().clamp( + self.train_config.min_denoising_steps, + self.train_config.max_denoising_steps - 1 + ) + + elif self.train_config.content_or_style == 'balanced': + timesteps = torch.randint( + self.train_config.min_denoising_steps, + self.train_config.max_denoising_steps, + (batch_size,), + device=self.device_torch + ) + timesteps = timesteps.long() else: - min_timestep = self.train_config.min_denoising_steps - - # todo improve this, but is skews odds for higher timesteps - # 50% chance to use midpoint as the min_time_step - mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2 - if torch.rand(1) > 0.5: - min_timestep = mid_point - - # 50% chance to use midpoint as the min_time_step - mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2 - if torch.rand(1) > 0.5: - min_timestep = mid_point - - min_timestep = int(min_timestep) - - timesteps = torch.randint( - min_timestep, - self.train_config.max_denoising_steps, - (batch_size,), - device=self.device_torch - ) - timesteps = timesteps.long() + raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}") # get noise noise = self.sd.get_latent_noise( @@ -477,6 +515,7 @@ class BaseSDTrainProcess(BaseTrainProcess): return noisy_latents, noise, timesteps, conditioned_prompts, imgs def run(self): + # torch.autograd.set_detect_anomaly(True) # run base process run BaseTrainProcess.run(self) @@ -653,9 +692,34 @@ class BaseSDTrainProcess(BaseTrainProcess): # set trainable params params = self.embedding.get_trainable_params() + flush() + elif self.adapter_config is not None: + self.adapter = T2IAdapter( + in_channels=self.adapter_config.in_channels, + channels=self.adapter_config.channels, + num_res_blocks=self.adapter_config.num_res_blocks, + 
downscale_factor=self.adapter_config.downscale_factor, + adapter_type=self.adapter_config.adapter_type, + ) + # t2i adapter + latest_save_path = self.get_latest_save_path(self.embed_config.trigger) + if latest_save_path is not None: + # load adapter from path + print(f"Loading adapter from {latest_save_path}") + loaded_state_dict = load_t2i_model( + latest_save_path, + self.device_torch, + dtype=dtype + ) + self.adapter.load_state_dict(loaded_state_dict) + self.load_training_state_from_metadata(latest_save_path) + params = self.get_params() + if not params: + # set trainable params + params = self.adapter.parameters() + self.sd.adapter = self.adapter flush() else: - # set the device state preset before getting params self.sd.set_device_state(self.train_device_state_preset) diff --git a/requirements.txt b/requirements.txt index 45a304ca..7bd2400f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torch torchvision safetensors -git+https://github.com/huggingface/diffusers.git +diffusers==0.21.1 transformers lycoris_lora flatten_json diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index f23a0eed..4979ae57 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -39,6 +39,7 @@ class SampleConfig: NetworkType = Literal['lora', 'locon'] + class NetworkConfig: def __init__(self, **kwargs): self.type: NetworkType = kwargs.get('type', 'lora') @@ -58,6 +59,17 @@ class NetworkConfig: self.dropout: Union[float, None] = kwargs.get('dropout', None) +class AdapterConfig: + def __init__(self, **kwargs): + self.in_channels: int = kwargs.get('in_channels', 3) + self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280]) + self.num_res_blocks: int = kwargs.get('num_res_blocks', 2) + self.downscale_factor: int = kwargs.get('downscale_factor', 16) + self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter') + self.image_dir: str = kwargs.get('image_dir', None) + self.test_img_path: str = kwargs.get('test_img_path', None) + + class EmbeddingConfig: def __init__(self, **kwargs): self.trigger = kwargs.get('trigger', 'custom_embedding') @@ -66,9 +78,13 @@ class EmbeddingConfig: self.save_format = kwargs.get('save_format', 'safetensors') +ContentOrStyleType = Literal['balanced', 'style', 'content'] + + class TrainConfig: def __init__(self, **kwargs): self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm') + self.content_or_style: ContentOrStyleType = kwargs.get('content_or_style', 'balanced') self.steps: int = kwargs.get('steps', 1000) self.lr = kwargs.get('lr', 1e-6) self.unet_lr = kwargs.get('unet_lr', self.lr) @@ -80,8 +96,6 @@ class TrainConfig: self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {}) self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0) self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000) - self.use_linear_denoising: int = kwargs.get('use_linear_denoising', False) - self.use_progressive_denoising: int = kwargs.get('use_progressive_denoising', False) self.batch_size: int = kwargs.get('batch_size', 1) self.dtype: str = kwargs.get('dtype', 'fp32') self.xformers = kwargs.get('xformers', False) @@ -255,6 +269,7 @@ class GenerateImageConfig: output_ext: str = ImgExt, # extension to save image as if output_path is not specified output_tail: str = '', # tail to add to output filename add_prompt_file: bool = False, # add a prompt file with generated image + adapter_image_path: str = None, # path to adapter image ): self.width: int = width self.height: int = height @@ -277,6 +292,7 
@@ class GenerateImageConfig: self.add_prompt_file: bool = add_prompt_file self.output_tail: str = output_tail self.gen_time: int = int(time.time() * 1000) + self.adapter_image_path: str = adapter_image_path # prompt string will override any settings above self._process_prompt_string() diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index ebfd0b66..925bb3b4 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -341,13 +341,7 @@ class ToolkitNetworkMixin: if isinstance(multiplier, int) or isinstance(multiplier, float): tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype) elif isinstance(multiplier, list): - tensor_list = [] - for m in multiplier: - if isinstance(m, int) or isinstance(m, float): - tensor_list.append(torch.tensor((m,)).to(device, dtype=dtype)) - elif isinstance(m, torch.Tensor): - tensor_list.append(m.clone().detach().to(device, dtype=dtype)) - tensor_multiplier = torch.cat(tensor_list) + tensor_multiplier = torch.tensor(multiplier).to(device, dtype=dtype) elif isinstance(multiplier, torch.Tensor): tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype) diff --git a/toolkit/saving.py b/toolkit/saving.py index d978f20f..38c4a1ad 100644 --- a/toolkit/saving.py +++ b/toolkit/saving.py @@ -161,10 +161,35 @@ def save_lora_from_diffusers( else: converted_key = key + # make sure parent folder exists + os.makedirs(os.path.dirname(output_file), exist_ok=True) + save_file(converted_state_dict, output_file, metadata=meta) +def save_t2i_from_diffusers( + t2i_state_dict: 'OrderedDict', + output_file: str, + meta: 'OrderedDict', + dtype=get_torch_dtype('fp16'), +): + # todo: test compatibility with non diffusers + converted_state_dict = OrderedDict() + for key, value in t2i_state_dict.items(): + converted_state_dict[key] = value.detach().to('cpu', dtype=dtype) # make sure parent folder exists os.makedirs(os.path.dirname(output_file), exist_ok=True) - save_file(converted_state_dict, output_file, metadata=meta -) + save_file(converted_state_dict, output_file, metadata=meta) + + +def load_t2i_model( + path_to_file, + device: Union[str, torch.device] = 'cpu', + dtype: torch.dtype = torch.float32 +): + raw_state_dict = load_file(path_to_file, device) + converted_state_dict = OrderedDict() + for key, value in raw_state_dict.items(): + # todo see if we need to convert dict + converted_state_dict[key] = value.detach().to(device, dtype=dtype) + return converted_state_dict diff --git a/toolkit/sd_device_states_presets.py b/toolkit/sd_device_states_presets.py index 9ffbd945..fc16fd9e 100644 --- a/toolkit/sd_device_states_presets.py +++ b/toolkit/sd_device_states_presets.py @@ -1,3 +1,5 @@ +from typing import Union + import torch import copy @@ -15,16 +17,22 @@ empty_preset = { 'training': False, 'requires_grad': False, 'device': 'cpu', - } + }, + 'adapter': { + 'training': False, + 'requires_grad': False, + 'device': 'cpu', + }, } -def get_train_sd_device_state_preset ( - device: torch.DeviceObjType, +def get_train_sd_device_state_preset( + device: Union[str, torch.device], train_unet: bool = False, train_text_encoder: bool = False, cached_latents: bool = False, train_lora: bool = False, + train_adapter: bool = False, train_embedding: bool = False, ): preset = copy.deepcopy(empty_preset) @@ -51,9 +59,14 @@ def get_train_sd_device_state_preset ( preset['text_encoder']['training'] = True preset['unet']['training'] = True - if train_lora: preset['text_encoder']['requires_grad'] = False preset['unet']['requires_grad'] = False + if 
train_adapter: + preset['adapter']['requires_grad'] = True + preset['adapter']['training'] = True + preset['adapter']['device'] = device + preset['unet']['training'] = True + return preset diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 77704e48..871f76d3 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1,3 +1,4 @@ +import copy import gc import json import shutil @@ -7,6 +8,7 @@ import sys import os from collections import OrderedDict +from PIL import Image from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg from safetensors.torch import save_file, load_file from torch.nn import Parameter @@ -22,11 +24,13 @@ from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds from toolkit.sampler import get_sampler from toolkit.saving import save_ldm_model_from_diffusers +from toolkit.sd_device_states_presets import empty_preset from toolkit.train_tools import get_torch_dtype, apply_noise_offset import torch from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \ StableDiffusionKDiffusionXLPipeline -from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ + StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline import diffusers # tell it to shut up @@ -110,7 +114,7 @@ class StableDiffusion: self.unet: Union[None, 'UNet2DConditionModel'] self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] - self.noise_scheduler: Union[None, 'KarrasDiffusionSchedulers'] = noise_scheduler + self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler # sdxl stuff self.logit_scale = None @@ -119,6 +123,7 @@ class StableDiffusion: # to hold network if there is one self.network = None + self.adapter: Union['T2IAdapter', None] = None self.is_xl = model_config.is_xl self.is_v2 = model_config.is_v2 @@ -291,8 +296,18 @@ class StableDiffusion: if sampler.startswith("sample_") and self.is_xl: # using kdiffusion Pipe = StableDiffusionKDiffusionXLPipeline - else: + elif self.is_xl: Pipe = StableDiffusionXLPipeline + else: + Pipe = StableDiffusionPipeline + + extra_args = {} + if self.adapter: + if self.is_xl: + Pipe = StableDiffusionXLAdapterPipeline + else: + Pipe = StableDiffusionAdapterPipeline + extra_args['adapter'] = self.adapter # TODO add clip skip if self.is_xl: @@ -305,11 +320,12 @@ class StableDiffusion: tokenizer_2=self.tokenizer[1], scheduler=noise_scheduler, add_watermarker=False, + **extra_args ).to(self.device_torch) # force turn that (ruin your images with obvious green and red dots) the #$@@ off!!! 
pipeline.watermark = None else: - pipeline = StableDiffusionPipeline( + pipeline = Pipe( vae=self.vae, unet=self.unet, text_encoder=self.text_encoder, @@ -318,6 +334,7 @@ class StableDiffusion: safety_checker=None, feature_extractor=None, requires_safety_checker=False, + **extra_args ).to(self.device_torch) flush() # disable progress bar @@ -340,6 +357,12 @@ class StableDiffusion: for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): gen_config = image_configs[i] + extra = {} + if gen_config.adapter_image_path is not None: + validation_image = Image.open(gen_config.adapter_image_path).convert("RGB") + validation_image = validation_image.resize((gen_config.width, gen_config.height)) + extra['image'] = validation_image + if self.network is not None: self.network.multiplier = gen_config.network_multiplier torch.manual_seed(gen_config.seed) @@ -355,7 +378,6 @@ class StableDiffusion: grs = 0.7 # grs = 0.0 - extra = {} if sampler.startswith("sample_"): extra['use_karras_sigmas'] = True @@ -379,6 +401,7 @@ class StableDiffusion: width=gen_config.width, num_inference_steps=gen_config.num_inference_steps, guidance_scale=gen_config.guidance_scale, + **extra ).images[0] gen_config.save_image(img) @@ -517,6 +540,7 @@ class StableDiffusion: timestep, encoder_hidden_states=text_embeddings.text_embeds, added_cond_kwargs=added_cond_kwargs, + **kwargs, ).sample if do_classifier_free_guidance: @@ -558,6 +582,7 @@ class StableDiffusion: latent_model_input, timestep, encoder_hidden_states=text_embeddings.text_embeds, + **kwargs, ).sample if do_classifier_free_guidance: @@ -855,6 +880,7 @@ class StableDiffusion: # saves the current device state for all modules # this is useful for when we want to alter the state and restore it self.device_state = { + **empty_preset, 'vae': { 'training': self.vae.training, 'device': self.vae.device, @@ -880,6 +906,12 @@ class StableDiffusion: 'device': self.text_encoder.device, 'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad } + if self.adapter is not None: + self.device_state['adapter'] = { + 'training': self.adapter.training, + 'device': self.adapter.device, + 'requires_grad': self.adapter.requires_grad, + } def restore_device_state(self): # restores the device state for all modules @@ -927,6 +959,14 @@ class StableDiffusion: self.text_encoder.eval() self.text_encoder.to(state['text_encoder']['device']) self.text_encoder.requires_grad_(state['text_encoder']['requires_grad']) + + if self.adapter is not None: + self.adapter.to(state['adapter']['device']) + self.adapter.requires_grad_(state['adapter']['requires_grad']) + if state['adapter']['training']: + self.adapter.train() + else: + self.adapter.eval() flush() def set_device_state_preset(self, device_state_preset: DeviceStatePreset): @@ -940,9 +980,9 @@ class StableDiffusion: if device_state_preset in ['cache_latents']: active_modules = ['vae'] if device_state_preset in ['generate']: - active_modules = ['vae', 'unet', 'text_encoder'] + active_modules = ['vae', 'unet', 'text_encoder', 'adapter'] - state = {} + state = copy.deepcopy(empty_preset) # vae state['vae'] = { 'training': 'vae' in training_modules, @@ -973,4 +1013,11 @@ class StableDiffusion: 'requires_grad': 'text_encoder' in training_modules, } + if self.adapter is not None: + state['adapter'] = { + 'training': 'adapter' in training_modules, + 'device': self.device_torch if 'adapter' in active_modules else 'cpu', + 'requires_grad': 'adapter' in training_modules, + } + self.set_device_state(state)
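For reference, a minimal standalone sketch of the sigma-weighted loss used in SDTrainer.hook_train_loop above, assuming the Karras-style convention x_t = x_0 + sigma * eps that the linked diffusers t2i-adapter example relies on (tensor shapes and the 0.05 perturbation below are illustrative only): weighting the x0-space error by sigma**-2 reproduces the plain eps-space MSE.

import torch

torch.manual_seed(0)
x0 = torch.randn(4, 4, 64, 64)            # clean latents
eps = torch.randn_like(x0)                # true noise
sigma = torch.rand(4, 1, 1, 1) * 10 + 0.1
x_t = x0 + sigma * eps                    # noisy latents (Karras convention)

eps_pred = eps + 0.05 * torch.randn_like(eps)   # stand-in for the UNet prediction
denoised = eps_pred * (-sigma) + x_t            # same form as denoised_latents above
weighing = sigma ** -2.0

loss_x0 = (weighing * (denoised - x0) ** 2).reshape(x0.shape[0], -1).mean(dim=1)
loss_eps = ((eps_pred - eps) ** 2).reshape(x0.shape[0], -1).mean(dim=1)
print(torch.allclose(loss_x0, loss_eps, atol=1e-5))  # True: the two reductions match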