From 8ef07a9c3662351328407d9249a5418639dd3052 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 15 Dec 2024 08:59:27 -0700
Subject: [PATCH] Added training for an experimental decorator embedding. Allow
 for turning off guidance embedding on flux (for unreleased model). Various
 bug fixes and modifications

---
 extensions_built_in/sd_trainer/SDTrainer.py | 12 +++++
 jobs/process/BaseSDTrainProcess.py          | 53 ++++++++++++++++++++-
 toolkit/config_modules.py                   | 12 ++++-
 toolkit/custom_adapter.py                   |  3 +-
 toolkit/models/decorator.py                 | 33 +++++++++++++
 toolkit/models/flux.py                      | 35 ++++++++++++++
 toolkit/optimizers/adafactor.py             |  4 ++
 toolkit/optimizers/automagic.py             |  4 +-
 toolkit/pipelines.py                        |  4 ++
 toolkit/sd_device_states_presets.py         |  9 ++++
 toolkit/stable_diffusion_model.py           | 23 +++++++--
 11 files changed, 182 insertions(+), 10 deletions(-)
 create mode 100644 toolkit/models/decorator.py
 create mode 100644 toolkit/models/flux.py

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index fb8ac6d8..2a4d051d 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -934,8 +934,10 @@ class SDTrainer(BaseSDTrainProcess):
                 unconditional_embeddings=unconditional_embeds,
                 timestep=timesteps,
                 guidance_scale=self.train_config.cfg_scale,
+                guidance_embedding_scale=self.train_config.cfg_scale,
                 detach_unconditional=False,
                 rescale_cfg=self.train_config.cfg_rescale,
+                bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
                 **kwargs
             )
 
@@ -1289,6 +1291,16 @@ class SDTrainer(BaseSDTrainProcess):
             conditional_embeds = conditional_embeds.detach()
             if self.train_config.do_cfg:
                 unconditional_embeds = unconditional_embeds.detach()
+
+        if self.decorator:
+            conditional_embeds.text_embeds = self.decorator(
+                conditional_embeds.text_embeds
+            )
+            if self.train_config.do_cfg:
+                unconditional_embeds.text_embeds = self.decorator(
+                    unconditional_embeds.text_embeds,
+                    is_unconditional=True
+                )
 
         # flush()
         pred_kwargs = {}
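Note on the first hunk above: guidance_scale keeps driving classifier-free guidance (mixing the conditional and unconditional predictions), while the new guidance_embedding_scale is only the scalar fed to Flux's distilled-guidance embedder, and bypass_guidance_embedding skips that embedder entirely. A minimal sketch of the distinction, with illustrative names rather than the toolkit's actual API:

    import torch

    def cfg_mix(uncond_pred: torch.Tensor, cond_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
        # classifier-free guidance: this is the only place guidance_scale is used
        return uncond_pred + guidance_scale * (cond_pred - uncond_pred)

    # guidance_embedding_scale never enters the formula above; it is expanded to a
    # per-sample tensor and handed to the transformer as conditioning, roughly
    # guidance = torch.tensor([guidance_embedding_scale]).expand(batch_size)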
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 1e1c36de..87984717 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -34,6 +34,7 @@ from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \
     lorm_ignore_if_contains, lorm_parameter_threshold, LORM_TARGET_REPLACE_MODULE
 from toolkit.lycoris_special import LycorisSpecialNetwork
+from toolkit.models.decorator import Decorator
 from toolkit.network_mixins import Network
 from toolkit.optimizer import get_optimizer
 from toolkit.paths import CONFIG_ROOT
@@ -56,7 +57,8 @@ import gc
 from tqdm import tqdm
 
 from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
-    GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs
+    GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs, \
+    DecoratorConfig
 from toolkit.logging import create_logger
 
 from diffusers import FluxTransformer2DModel
@@ -143,6 +145,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
         embedding_raw = self.get_conf('embedding', None)
         if embedding_raw is not None:
             self.embed_config = EmbeddingConfig(**embedding_raw)
+
+        self.decorator_config: DecoratorConfig = None
+        decorator_raw = self.get_conf('decorator', None)
+        if decorator_raw is not None:
+            if not self.model_config.is_flux:
+                raise ValueError("Decorators are only supported for Flux models currently")
+            self.decorator_config = DecoratorConfig(**decorator_raw)
 
         # t2i adapter
         self.adapter_config = None
@@ -157,6 +166,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.network: Union[Network, None] = None
         self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, ControlNetModel, None] = None
         self.embedding: Union[Embedding, None] = None
+        self.decorator: Union[Decorator, None] = None
 
         is_training_adapter = self.adapter_config is not None and self.adapter_config.train
 
@@ -174,6 +184,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_lora=self.network_config is not None,
             train_adapter=is_training_adapter,
             train_embedding=self.embed_config is not None,
+            train_decorator=self.decorator_config is not None,
             train_refiner=self.train_config.train_refiner,
             unload_text_encoder=self.train_config.unload_text_encoder,
             require_grads=False  # we ensure them later
@@ -187,6 +198,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_lora=self.network_config is not None,
             train_adapter=is_training_adapter,
             train_embedding=self.embed_config is not None,
+            train_decorator=self.decorator_config is not None,
             train_refiner=self.train_config.train_refiner,
             unload_text_encoder=self.train_config.unload_text_encoder,
             require_grads=True  # We check for grads when getting params
@@ -194,7 +206,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         )
         # fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc)
         self.is_fine_tuning = True
-        if self.network_config is not None or is_training_adapter or self.embed_config is not None:
+        if self.network_config is not None or is_training_adapter or self.embed_config is not None or self.decorator_config is not None:
             self.is_fine_tuning = False
 
         self.named_lora = False
@@ -468,6 +480,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 # replace extension
                 emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
             self.embedding.save(emb_file_path)
+
+        if self.decorator is not None:
+            dec_filename = f'{self.job.name}{step_num}.safetensors'
+            dec_file_path = os.path.join(self.save_root, dec_filename)
+            decorator_state_dict = self.decorator.state_dict()
+            for key, value in decorator_state_dict.items():
+                if isinstance(value, torch.Tensor):
+                    decorator_state_dict[key] = value.clone().to('cpu', dtype=get_torch_dtype(self.save_config.dtype))
+            save_file(
+                decorator_state_dict,
+                dec_file_path,
+                metadata=save_meta,
+            )
 
         if self.adapter is not None and self.adapter_config.train:
             adapter_name = self.job.name
@@ -1506,6 +1531,30 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 })
 
             flush()
+
+        if self.decorator_config is not None:
+            self.decorator = Decorator(
+                num_tokens=self.decorator_config.num_tokens,
+                token_size=4096  # t5xxl hidden size for flux
+            )
+            latest_save_path = self.get_latest_save_path()
+            # load last saved weights
+            if latest_save_path is not None:
+                state_dict = load_file(latest_save_path)
+                self.decorator.load_state_dict(state_dict)
+                self.load_training_state_from_metadata(latest_save_path)
+
+            params.append({
+                'params': list(self.decorator.parameters()),
+                'lr': self.train_config.lr
+            })
+
+            # give it to the sd network
+            self.sd.decorator = self.decorator
+            self.decorator.to(self.device_torch, dtype=torch.float32)
+            self.decorator.train()
+
+            flush()
 
         if self.adapter_config is not None:
             self.setup_adapter()
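Note: the decorator path is switched on by a decorator block in the job config, which the hunk above in BaseSDTrainProcess.__init__ parses into the new DecoratorConfig (Flux-only for now). A hypothetical minimal fragment, shown here as the parsed dict rather than the YAML config file:

    from toolkit.config_modules import DecoratorConfig  # added later in this patch

    # hypothetical contents of the job config's 'decorator:' block
    decorator_raw = {
        'num_tokens': 4,  # learned tokens appended to the T5 sequence (default 4)
    }
    decorator_config = DecoratorConfig(**decorator_raw)
    assert decorator_config.num_tokens == 4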
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 5d9dc92b..1e7215bf 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -227,6 +227,11 @@ class EmbeddingConfig:
         self.trigger_class_name = kwargs.get('trigger_class_name', None)  # used for inverted masked prior
 
 
+class DecoratorConfig:
+    def __init__(self, **kwargs):
+        self.num_tokens: int = kwargs.get('num_tokens', 4)
+
+
 ContentOrStyleType = Literal['balanced', 'style', 'content']
 LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
 
@@ -393,6 +398,8 @@ class TrainConfig:
         self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False)
         # 0.1 is 10% of the parameters active at a time. lower is less vram, higher is more
         self.paramiter_swapping_factor = kwargs.get('paramiter_swapping_factor', 0.1)
+        # bypass the guidance embedding during training. For open flux models that have a guidance embedding
+        self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False)
 
 
 class ModelConfig:
@@ -458,6 +465,7 @@ class ModelConfig:
         # for targeting a specific layers
         self.ignore_if_contains: Optional[List[str]] = kwargs.get("ignore_if_contains", None)
         self.only_if_contains: Optional[List[str]] = kwargs.get("only_if_contains", None)
+        self.quantize_kwargs = kwargs.get("quantize_kwargs", {})
 
         if self.ignore_if_contains is not None or self.only_if_contains is not None:
             if not self.is_flux:
@@ -914,4 +922,6 @@ def validate_configs(
         if save_config.save_format != 'diffusers':
             # make it diffusers
             save_config.save_format = 'diffusers'
-    
\ No newline at end of file
+    if model_config.use_flux_cfg:
+        # bypass the embedding
+        train_config.bypass_guidance_embedding = True
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 6ca66020..12a4df4b 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -205,7 +205,8 @@ class CustomAdapter(torch.nn.Module):
         elif self.adapter_type == 'single_value':
             self.single_value_adapter = SingleValueAdapter(self, self.sd_ref(), num_values=self.config.num_tokens)
         elif self.adapter_type == 'redux':
-            self.redux_adapter = ReduxImageEncoder(1152, 4096, self.device, torch_dtype)
+            vision_hidden_size = self.vision_encoder.config.hidden_size
+            self.redux_adapter = ReduxImageEncoder(vision_hidden_size, 4096, self.device, torch_dtype)
         else:
             raise ValueError(f"unknown adapter type: {self.adapter_type}")
 
diff --git a/toolkit/models/decorator.py b/toolkit/models/decorator.py
new file mode 100644
index 00000000..63f45aa9
--- /dev/null
+++ b/toolkit/models/decorator.py
@@ -0,0 +1,33 @@
+import torch
+
+
+class Decorator(torch.nn.Module):
+    def __init__(
+        self,
+        num_tokens: int = 4,
+        token_size: int = 4096,
+    ) -> None:
+        super().__init__()
+
+        self.weight: torch.nn.Parameter = torch.nn.Parameter(
+            torch.randn(num_tokens, token_size)
+        )
+        # ensure it is float32
+        self.weight.data = self.weight.data.float()
+
+    def forward(self, text_embeds: torch.Tensor, is_unconditional=False) -> torch.Tensor:
+        # make sure the param is float32
+        if self.weight.dtype != text_embeds.dtype:
+            self.weight.data = self.weight.data.float()
+        # expand batch to match text_embeds
+        batch_size = text_embeds.shape[0]
+        decorator_embeds = self.weight.unsqueeze(0).expand(batch_size, -1, -1)
+        if is_unconditional:
+            # zero pad the decorator embeds
+            decorator_embeds = torch.zeros_like(decorator_embeds)
+
+        if decorator_embeds.dtype != text_embeds.dtype:
+            decorator_embeds = decorator_embeds.to(text_embeds.dtype)
+        text_embeds = torch.cat((text_embeds, decorator_embeds), dim=-2)
+
+        return text_embeds
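A quick shape check of the new module above (a standalone sketch; the 512-token sequence length just mirrors Flux's default max_sequence_length): the Decorator appends num_tokens learned embeddings along the sequence dimension of the T5 text embeddings, and substitutes zeros on the unconditional branch so the learned tokens only condition the positive prompt.

    import torch
    from toolkit.models.decorator import Decorator

    dec = Decorator(num_tokens=4, token_size=4096)
    t5_embeds = torch.randn(2, 512, 4096)          # (batch, seq_len, hidden)

    cond = dec(t5_embeds)                           # -> torch.Size([2, 516, 4096])
    uncond = dec(t5_embeds, is_unconditional=True)  # same shape, extra 4 tokens are zeros
    print(cond.shape, uncond.shape)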
diff --git a/toolkit/models/flux.py b/toolkit/models/flux.py
new file mode 100644
index 00000000..48ce8786
--- /dev/null
+++ b/toolkit/models/flux.py
@@ -0,0 +1,35 @@
+
+# forward that bypasses the guidance embedding so it can be avoided during training.
+from functools import partial
+
+
+def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
+    timesteps_proj = self.time_proj(timestep)
+    timesteps_emb = self.timestep_embedder(
+        timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+    pooled_projections = self.text_embedder(pooled_projection)
+    conditioning = timesteps_emb + pooled_projections
+    return conditioning
+
+# bypass the forward function
+
+
+def bypass_flux_guidance(transformer):
+    if hasattr(transformer.time_text_embed, '_bfg_orig_forward'):
+        return
+    # don't bypass if it doesn't have the guidance embedding
+    if not hasattr(transformer.time_text_embed, 'guidance_embedder'):
+        return
+    transformer.time_text_embed._bfg_orig_forward = transformer.time_text_embed.forward
+    transformer.time_text_embed.forward = partial(
+        guidance_embed_bypass_forward, transformer.time_text_embed
+    )
+
+# restore the forward function
+
+
+def restore_flux_guidance(transformer):
+    if not hasattr(transformer.time_text_embed, '_bfg_orig_forward'):
+        return
+    transformer.time_text_embed.forward = transformer.time_text_embed._bfg_orig_forward
+    del transformer.time_text_embed._bfg_orig_forward
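Sketch of how the two helpers above are meant to be used: time_text_embed.forward is monkey-patched so the guidance embedder is skipped, then restored after the denoising call. The patch calls them directly around the forward pass; the try/finally here is an assumption for robustness, not what the toolkit itself does:

    from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance

    # `transformer` is assumed to be an already-loaded diffusers FluxTransformer2DModel
    bypass_flux_guidance(transformer)   # no-op if the model has no guidance_embedder
    try:
        # ... run the transformer forward / denoising steps here ...
        pass
    finally:
        restore_flux_guidance(transformer)  # put the original forward back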
diff --git a/toolkit/optimizers/adafactor.py b/toolkit/optimizers/adafactor.py
index 2f1a8997..00cf06ee 100644
--- a/toolkit/optimizers/adafactor.py
+++ b/toolkit/optimizers/adafactor.py
@@ -264,6 +264,10 @@ class Adafactor(torch.optim.Optimizer):
                 if grad.is_sparse:
                     raise RuntimeError(
                         "Adafactor does not support sparse gradients.")
+
+                # if p has attr _scale then it is quantized. We need to divide the grad by the scale
+                # if hasattr(p, "_scale"):
+                #     grad = grad / p._scale
 
                 state = self.state[p]
                 grad_shape = grad.shape
diff --git a/toolkit/optimizers/automagic.py b/toolkit/optimizers/automagic.py
index 638a260b..ac7355f1 100644
--- a/toolkit/optimizers/automagic.py
+++ b/toolkit/optimizers/automagic.py
@@ -101,7 +101,7 @@ class Automagic(torch.optim.Optimizer):
         if 'avg_lr' in param_state:
             lr = param_state["avg_lr"]
         else:
-            lr = param_state["lr"]
+            lr = 0.0
         return lr
 
     def _get_group_lr(self, group):
@@ -332,4 +332,4 @@ class Automagic(torch.optim.Optimizer):
                 state['lr_mask'] = Auto8bitTensor(sd_mask)
                 del state_dict['state'][idx]['lr_mask']
             idx += 1
-        super().load_state_dict(state_dict, strict)
+        super().load_state_dict(state_dict)
diff --git a/toolkit/pipelines.py b/toolkit/pipelines.py
index e2efa3ee..c0509ee1 100644
--- a/toolkit/pipelines.py
+++ b/toolkit/pipelines.py
@@ -14,6 +14,7 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import
 from diffusers.utils import is_torch_xla_available
 from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
 from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
+from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
 
 
 if is_torch_xla_available():
@@ -1235,6 +1236,8 @@ class FluxWithCFGPipeline(FluxPipeline):
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
     ):
+        # bypass the guidance embedding if there is one
+        bypass_flux_guidance(self.transformer)
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
 
@@ -1410,6 +1413,7 @@ class FluxWithCFGPipeline(FluxPipeline):
 
         # Offload all models
         self.maybe_free_model_hooks()
+        restore_flux_guidance(self.transformer)
 
         if not return_dict:
             return (image,)
diff --git a/toolkit/sd_device_states_presets.py b/toolkit/sd_device_states_presets.py
index 0a8918ef..1eeecc32 100644
--- a/toolkit/sd_device_states_presets.py
+++ b/toolkit/sd_device_states_presets.py
@@ -39,6 +39,7 @@ def get_train_sd_device_state_preset(
         train_lora: bool = False,
         train_adapter: bool = False,
         train_embedding: bool = False,
+        train_decorator: bool = False,
         train_refiner: bool = False,
         unload_text_encoder: bool = False,
         require_grads: bool = True,
@@ -89,6 +90,14 @@ def get_train_sd_device_state_preset(
         preset['unet']['requires_grad'] = False
         preset['unet']['device'] = device
         preset['text_encoder']['device'] = device
+
+    if train_decorator:
+        preset['text_encoder']['training'] = False
+        preset['text_encoder']['requires_grad'] = False
+        preset['text_encoder']['device'] = device
+        preset['unet']['training'] = True
+        preset['unet']['requires_grad'] = False
+        preset['unet']['device'] = device
 
     if unload_text_encoder:
         preset['text_encoder']['training'] = False
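For reference, the train_decorator branch above keeps both the text encoder and the transformer frozen on the training device; only the decorator's own parameters (added to the optimizer in BaseSDTrainProcess) receive gradients. The touched preset entries resolve to roughly the following, where 'cuda' stands in for whatever device is passed to the preset helper:

    # approximate effect of the train_decorator branch (other preset keys unchanged)
    decorator_preset_excerpt = {
        'text_encoder': {'training': False, 'requires_grad': False, 'device': 'cuda'},
        'unet':         {'training': True,  'requires_grad': False, 'device': 'cuda'},
    }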
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 2911c646..23439fbc 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -31,6 +31,7 @@ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encod
 from toolkit import train_tools
 from toolkit.config_modules import ModelConfig, GenerateImageConfig
 from toolkit.metadata import get_meta_for_safetensors
+from toolkit.models.decorator import Decorator
 from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
 from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
 from toolkit.reference_adapter import ReferenceAdapter
@@ -59,6 +60,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjecti
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from huggingface_hub import hf_hub_download
 
+from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
 from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
 
 from typing import TYPE_CHECKING
@@ -165,6 +167,7 @@ class StableDiffusion:
         # to hold network if there is one
         self.network = None
         self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
+        self.decorator: Union[Decorator, None] = None
         self.is_xl = model_config.is_xl
         self.is_v2 = model_config.is_v2
         self.is_ssd = model_config.is_ssd
@@ -668,7 +671,7 @@ class StableDiffusion:
                     patch_dequantization_on_save(transformer)
                     quantization_type = qfloat8
                     print("Quantizing transformer")
-                    quantize(transformer, weights=quantization_type)
+                    quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
                     freeze(transformer)
                     transformer.to(self.device_torch)
                 else:
@@ -1209,6 +1212,11 @@ class StableDiffusion:
                     conditional_embeds,
                     unconditional_embeds,
                 )
+
+            if self.decorator is not None:
+                # apply the decorator to the embeddings
+                conditional_embeds.text_embeds = self.decorator(conditional_embeds.text_embeds)
+                unconditional_embeds.text_embeds = self.decorator(unconditional_embeds.text_embeds, is_unconditional=True)
 
             if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
                     and gen_config.adapter_image_path is not None:
@@ -1566,6 +1574,7 @@ class StableDiffusion:
             rescale_cfg=None,
             return_conditional_pred=False,
             guidance_embedding_scale=1.0,
+            bypass_guidance_embedding=False,
             **kwargs,
     ):
         conditional_pred = None
@@ -1842,13 +1851,16 @@ class StableDiffusion:
 
                 # # handle guidance
                 if self.unet.config.guidance_embeds:
-                    if isinstance(guidance_scale, list):
-                        guidance = torch.tensor(guidance_scale, device=self.device_torch)
+                    if isinstance(guidance_embedding_scale, list):
+                        guidance = torch.tensor(guidance_embedding_scale, device=self.device_torch)
                     else:
-                        guidance = torch.tensor([guidance_scale], device=self.device_torch)
+                        guidance = torch.tensor([guidance_embedding_scale], device=self.device_torch)
                     guidance = guidance.expand(latents.shape[0])
                 else:
                     guidance = None
+
+                if bypass_guidance_embedding:
+                    bypass_flux_guidance(self.unet)
 
                 cast_dtype = self.unet.dtype
                 # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
@@ -1879,6 +1891,9 @@ class StableDiffusion:
                         pw=2,
                         c=latent_model_input.shape[1],
                     )
+
+                if bypass_guidance_embedding:
+                    restore_flux_guidance(self.unet)
             elif self.is_v3:
                 noise_pred = self.unet(
                     hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
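The trained decorator is saved as a plain safetensors state dict (see the save hook in BaseSDTrainProcess above), and StableDiffusion.generate_images applies self.decorator to both prompt branches whenever it is set. A rough sketch of wiring saved weights back in for inference, assuming sd is an already-loaded Flux StableDiffusion instance and using a hypothetical checkpoint path:

    import torch
    from safetensors.torch import load_file
    from toolkit.models.decorator import Decorator

    state_dict = load_file("output/my_job_000000500.safetensors")  # hypothetical path

    decorator = Decorator(num_tokens=4, token_size=4096)  # must match the training config
    decorator.load_state_dict(state_dict)
    decorator.to(sd.device_torch, dtype=torch.float32)
    decorator.eval()

    # generate_images() will now append the learned tokens to the conditional embeds
    # and zeroed tokens to the unconditional embeds automatically
    sd.decorator = decorator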