Added training for an experimental decorator embedding. Allow turning off the guidance embedding on Flux (for an unreleased model). Various bug fixes and modifications.

This commit is contained in:
Jaret Burkett
2024-12-15 08:59:27 -07:00
parent 92ce93140e
commit 8ef07a9c36
11 changed files with 182 additions and 10 deletions

View File

@@ -934,8 +934,10 @@ class SDTrainer(BaseSDTrainProcess):
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
guidance_scale=self.train_config.cfg_scale,
guidance_embedding_scale=self.train_config.cfg_scale,
detach_unconditional=False,
rescale_cfg=self.train_config.cfg_rescale,
bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
**kwargs
)
@@ -1289,6 +1291,16 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds = conditional_embeds.detach()
if self.train_config.do_cfg:
unconditional_embeds = unconditional_embeds.detach()
if self.decorator:
conditional_embeds.text_embeds = self.decorator(
conditional_embeds.text_embeds
)
if self.train_config.do_cfg:
unconditional_embeds.text_embeds = self.decorator(
unconditional_embeds.text_embeds,
is_unconditional=True
)
# flush()
pred_kwargs = {}
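
The embeds are detached above so that gradients reach only the decorator's learned tokens, not the text encoder output. A minimal sketch of that behavior using the Decorator module added later in this commit (shapes are illustrative):

import torch
from toolkit.models.decorator import Decorator

dec = Decorator(num_tokens=4, token_size=4096)
text_embeds = torch.randn(1, 8, 4096).detach()   # stand-in for a detached T5 sequence embedding
out = dec(text_embeds)                           # learned tokens concatenated onto the sequence
out.sum().backward()
print(dec.weight.grad.shape)   # torch.Size([4, 4096]) -- the decorator receives a gradient
print(text_embeds.grad)        # None -- the detached text embeds do not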

View File

@@ -34,6 +34,7 @@ from toolkit.lora_special import LoRASpecialNetwork
from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \
lorm_ignore_if_contains, lorm_parameter_threshold, LORM_TARGET_REPLACE_MODULE
from toolkit.lycoris_special import LycorisSpecialNetwork
from toolkit.models.decorator import Decorator
from toolkit.network_mixins import Network
from toolkit.optimizer import get_optimizer
from toolkit.paths import CONFIG_ROOT
@@ -56,7 +57,8 @@ import gc
from tqdm import tqdm
from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs, \
DecoratorConfig
from toolkit.logging import create_logger
from diffusers import FluxTransformer2DModel
@@ -143,6 +145,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
embedding_raw = self.get_conf('embedding', None)
if embedding_raw is not None:
self.embed_config = EmbeddingConfig(**embedding_raw)
self.decorator_config: DecoratorConfig = None
decorator_raw = self.get_conf('decorator', None)
if decorator_raw is not None:
if not self.model_config.is_flux:
raise ValueError("Decorators are only supported for Flux models currently")
self.decorator_config = DecoratorConfig(**decorator_raw)
# t2i adapter
self.adapter_config = None
@@ -157,6 +166,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.network: Union[Network, None] = None
self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, ControlNetModel, None] = None
self.embedding: Union[Embedding, None] = None
self.decorator: Union[Decorator, None] = None
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
@@ -174,6 +184,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
train_lora=self.network_config is not None,
train_adapter=is_training_adapter,
train_embedding=self.embed_config is not None,
train_decorator=self.decorator_config is not None,
train_refiner=self.train_config.train_refiner,
unload_text_encoder=self.train_config.unload_text_encoder,
require_grads=False # we ensure them later
@@ -187,6 +198,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
train_lora=self.network_config is not None,
train_adapter=is_training_adapter,
train_embedding=self.embed_config is not None,
train_decorator=self.decorator_config is not None,
train_refiner=self.train_config.train_refiner,
unload_text_encoder=self.train_config.unload_text_encoder,
require_grads=True # We check for grads when getting params
@@ -194,7 +206,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# fine_tuning here means training the actual SD network (Dreambooth, etc.), not LoRA, embeddings, etc.
self.is_fine_tuning = True
if self.network_config is not None or is_training_adapter or self.embed_config is not None:
if self.network_config is not None or is_training_adapter or self.embed_config is not None or self.decorator_config is not None:
self.is_fine_tuning = False
self.named_lora = False
@@ -468,6 +480,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
# replace extension
emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
self.embedding.save(emb_file_path)
if self.decorator is not None:
dec_filename = f'{self.job.name}{step_num}.safetensors'
dec_file_path = os.path.join(self.save_root, dec_filename)
decorator_state_dict = self.decorator.state_dict()
for key, value in decorator_state_dict.items():
if isinstance(value, torch.Tensor):
decorator_state_dict[key] = value.clone().to('cpu', dtype=get_torch_dtype(self.save_config.dtype))
save_file(
decorator_state_dict,
dec_file_path,
metadata=save_meta,
)
if self.adapter is not None and self.adapter_config.train:
adapter_name = self.job.name
@@ -1506,6 +1531,30 @@ class BaseSDTrainProcess(BaseTrainProcess):
})
flush()
if self.decorator_config is not None:
self.decorator = Decorator(
num_tokens=self.decorator_config.num_tokens,
token_size=4096 # t5xxl hidden size for flux
)
latest_save_path = self.get_latest_save_path()
# load last saved weights
if latest_save_path is not None:
state_dict = load_file(latest_save_path)
self.decorator.load_state_dict(state_dict)
self.load_training_state_from_metadata(latest_save_path)
params.append({
'params': list(self.decorator.parameters()),
'lr': self.train_config.lr
})
# give it to the sd network
self.sd.decorator = self.decorator
self.decorator.to(self.device_torch, dtype=torch.float32)
self.decorator.train()
flush()
if self.adapter_config is not None:
self.setup_adapter()
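
A sketch of reloading a saved decorator outside the trainer, mirroring the save and setup logic above; the file path is hypothetical and num_tokens must match the trained configuration:

from safetensors.torch import load_file
from toolkit.models.decorator import Decorator

decorator = Decorator(num_tokens=4, token_size=4096)            # 4096 = T5-XXL hidden size for Flux
state_dict = load_file("output/my_job_000001000.safetensors")   # hypothetical save path
decorator.load_state_dict(state_dict)
decorator.eval()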

View File

@@ -227,6 +227,11 @@ class EmbeddingConfig:
self.trigger_class_name = kwargs.get('trigger_class_name', None) # used for inverted masked prior
class DecoratorConfig:
def __init__(self, **kwargs):
self.num_tokens: int = kwargs.get('num_tokens', 4)
ContentOrStyleType = Literal['balanced', 'style', 'content']
LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
@@ -393,6 +398,8 @@ class TrainConfig:
self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False)
# 0.1 means 10% of the parameters are active at a time; lower uses less VRAM, higher uses more
self.paramiter_swapping_factor = kwargs.get('paramiter_swapping_factor', 0.1)
# bypass the guidance embedding during training (for open Flux models that still have a guidance embedder)
self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False)
class ModelConfig:
@@ -458,6 +465,7 @@ class ModelConfig:
# for targeting a specific layers
self.ignore_if_contains: Optional[List[str]] = kwargs.get("ignore_if_contains", None)
self.only_if_contains: Optional[List[str]] = kwargs.get("only_if_contains", None)
self.quantize_kwargs = kwargs.get("quantize_kwargs", {})
if self.ignore_if_contains is not None or self.only_if_contains is not None:
if not self.is_flux:
@@ -914,4 +922,6 @@ def validate_configs(
if save_config.save_format != 'diffusers':
# make it diffusers
save_config.save_format = 'diffusers'
if model_config.use_flux_cfg:
# bypass the embedding
train_config.bypass_guidance_embedding = True
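
For reference, a sketch of how the options added in this commit could be wired into a process config, written as a Python dict; only the key names (decorator, num_tokens, bypass_guidance_embedding, quantize_kwargs) come from the diff, the surrounding nesting is assumed:

process_config = {
    "decorator": {                           # read via get_conf('decorator', None); Flux-only for now
        "num_tokens": 4,                     # DecoratorConfig default
    },
    "train": {
        "bypass_guidance_embedding": True,   # skip Flux's guidance embedder during training
    },
    "model": {
        "quantize_kwargs": {},               # forwarded to optimum.quanto.quantize(...)
    },
}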

View File

@@ -205,7 +205,8 @@ class CustomAdapter(torch.nn.Module):
elif self.adapter_type == 'single_value':
self.single_value_adapter = SingleValueAdapter(self, self.sd_ref(), num_values=self.config.num_tokens)
elif self.adapter_type == 'redux':
self.redux_adapter = ReduxImageEncoder(1152, 4096, self.device, torch_dtype)
vision_hidden_size = self.vision_encoder.config.hidden_size
self.redux_adapter = ReduxImageEncoder(vision_hidden_size, 4096, self.device, torch_dtype)
else:
raise ValueError(f"unknown adapter type: {self.adapter_type}")

View File

@@ -0,0 +1,33 @@
import torch
class Decorator(torch.nn.Module):
def __init__(
self,
num_tokens: int = 4,
token_size: int = 4096,
) -> None:
super().__init__()
self.weight: torch.nn.Parameter = torch.nn.Parameter(
torch.randn(num_tokens, token_size)
)
# ensure it is float32
self.weight.data = self.weight.data.float()
def forward(self, text_embeds: torch.Tensor, is_unconditional=False) -> torch.Tensor:
# make sure the param is float32
if self.weight.dtype != text_embeds.dtype:
self.weight.data = self.weight.data.float()
# expand batch to match text_embeds
batch_size = text_embeds.shape[0]
decorator_embeds = self.weight.unsqueeze(0).expand(batch_size, -1, -1)
if is_unconditional:
# zero pad the decorator embeds
decorator_embeds = torch.zeros_like(decorator_embeds)
if decorator_embeds.dtype != text_embeds.dtype:
decorator_embeds = decorator_embeds.to(text_embeds.dtype)
text_embeds = torch.cat((text_embeds, decorator_embeds), dim=-2)
return text_embeds
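
A minimal shape check for the module above, assuming it is importable as toolkit.models.decorator; the unconditional branch appends zeros so conditional and unconditional sequences keep the same length for CFG:

import torch
from toolkit.models.decorator import Decorator

dec = Decorator(num_tokens=4, token_size=4096)                # 4096 = T5-XXL hidden size used by Flux
t5_embeds = torch.randn(2, 512, 4096)                         # (batch, seq_len, hidden)
cond = dec(t5_embeds)
print(cond.shape)                                             # torch.Size([2, 516, 4096])
uncond = dec(t5_embeds, is_unconditional=True)
print(torch.equal(uncond[:, 512:], torch.zeros(2, 4, 4096)))  # True: zeros stand in for the learned tokens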

toolkit/models/flux.py Normal file
View File

@@ -0,0 +1,35 @@
# forward that bypasses the guidance embedding so it can be avoided during training.
from functools import partial
def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(
timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
pooled_projections = self.text_embedder(pooled_projection)
conditioning = timesteps_emb + pooled_projections
return conditioning
# bypass the forward function
def bypass_flux_guidance(transformer):
if hasattr(transformer.time_text_embed, '_bfg_orig_forward'):
return
# don't bypass if it doesn't have a guidance embedder
if not hasattr(transformer.time_text_embed, 'guidance_embedder'):
return
transformer.time_text_embed._bfg_orig_forward = transformer.time_text_embed.forward
transformer.time_text_embed.forward = partial(
guidance_embed_bypass_forward, transformer.time_text_embed
)
# restore the forward function
def restore_flux_guidance(transformer):
if not hasattr(transformer.time_text_embed, '_bfg_orig_forward'):
return
transformer.time_text_embed.forward = transformer.time_text_embed._bfg_orig_forward
del transformer.time_text_embed._bfg_orig_forward
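
A minimal usage sketch, assuming a Flux transformer with a guidance_embedder is already loaded as transformer; the bypass is a reversible monkey-patch of time_text_embed.forward:

from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance

orig_forward = transformer.time_text_embed.forward
bypass_flux_guidance(transformer)        # stashes the original forward and patches in the bypass
assert hasattr(transformer.time_text_embed, '_bfg_orig_forward')
# ... run training steps or sampling without the guidance embedding ...
restore_flux_guidance(transformer)       # puts the original forward back
assert transformer.time_text_embed.forward == orig_forward   # bound-method equality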

View File

@@ -264,6 +264,10 @@ class Adafactor(torch.optim.Optimizer):
if grad.is_sparse:
raise RuntimeError(
"Adafactor does not support sparse gradients.")
# if p has attr _scale then it is quantized. We need to divide the grad by the scale
# if hasattr(p, "_scale"):
# grad = grad / p._scale
state = self.state[p]
grad_shape = grad.shape

View File

@@ -101,7 +101,7 @@ class Automagic(torch.optim.Optimizer):
if 'avg_lr' in param_state:
lr = param_state["avg_lr"]
else:
lr = param_state["lr"]
lr = 0.0
return lr
def _get_group_lr(self, group):
@@ -332,4 +332,4 @@ class Automagic(torch.optim.Optimizer):
state['lr_mask'] = Auto8bitTensor(sd_mask)
del state_dict['state'][idx]['lr_mask']
idx += 1
super().load_state_dict(state_dict, strict)
super().load_state_dict(state_dict)

View File

@@ -14,6 +14,7 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import
from diffusers.utils import is_torch_xla_available
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
if is_torch_xla_available():
@@ -1235,6 +1236,8 @@ class FluxWithCFGPipeline(FluxPipeline):
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
# bypass the guidance embedding if there is one
bypass_flux_guidance(self.transformer)
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
@@ -1410,6 +1413,7 @@ class FluxWithCFGPipeline(FluxPipeline):
# Offload all models
self.maybe_free_model_hooks()
restore_flux_guidance(self.transformer)
if not return_dict:
return (image,)

View File

@@ -39,6 +39,7 @@ def get_train_sd_device_state_preset(
train_lora: bool = False,
train_adapter: bool = False,
train_embedding: bool = False,
train_decorator: bool = False,
train_refiner: bool = False,
unload_text_encoder: bool = False,
require_grads: bool = True,
@@ -89,6 +90,14 @@ def get_train_sd_device_state_preset(
preset['unet']['requires_grad'] = False
preset['unet']['device'] = device
preset['text_encoder']['device'] = device
if train_decorator:
preset['text_encoder']['training'] = False
preset['text_encoder']['requires_grad'] = False
preset['text_encoder']['device'] = device
preset['unet']['training'] = True
preset['unet']['requires_grad'] = False
preset['unet']['device'] = device
if unload_text_encoder:
preset['text_encoder']['training'] = False

View File

@@ -31,6 +31,7 @@ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encod
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.decorator import Decorator
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
from toolkit.reference_adapter import ReferenceAdapter
@@ -59,6 +60,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjecti
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from huggingface_hub import hf_hub_download
from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
from typing import TYPE_CHECKING
@@ -165,6 +167,7 @@ class StableDiffusion:
# to hold network if there is one
self.network = None
self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
self.decorator: Union[Decorator, None] = None
self.is_xl = model_config.is_xl
self.is_v2 = model_config.is_v2
self.is_ssd = model_config.is_ssd
@@ -668,7 +671,7 @@ class StableDiffusion:
patch_dequantization_on_save(transformer)
quantization_type = qfloat8
print("Quantizing transformer")
quantize(transformer, weights=quantization_type)
quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
freeze(transformer)
transformer.to(self.device_torch)
else:
@@ -1209,6 +1212,11 @@ class StableDiffusion:
conditional_embeds,
unconditional_embeds,
)
if self.decorator is not None:
# apply the decorator to the embeddings
conditional_embeds.text_embeds = self.decorator(conditional_embeds.text_embeds)
unconditional_embeds.text_embeds = self.decorator(unconditional_embeds.text_embeds, is_unconditional=True)
if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
and gen_config.adapter_image_path is not None:
@@ -1566,6 +1574,7 @@ class StableDiffusion:
rescale_cfg=None,
return_conditional_pred=False,
guidance_embedding_scale=1.0,
bypass_guidance_embedding=False,
**kwargs,
):
conditional_pred = None
@@ -1842,13 +1851,16 @@ class StableDiffusion:
# # handle guidance
if self.unet.config.guidance_embeds:
if isinstance(guidance_scale, list):
guidance = torch.tensor(guidance_scale, device=self.device_torch)
if isinstance(guidance_embedding_scale, list):
guidance = torch.tensor(guidance_embedding_scale, device=self.device_torch)
else:
guidance = torch.tensor([guidance_scale], device=self.device_torch)
guidance = torch.tensor([guidance_embedding_scale], device=self.device_torch)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
if bypass_guidance_embedding:
bypass_flux_guidance(self.unet)
cast_dtype = self.unet.dtype
# with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
@@ -1879,6 +1891,9 @@ class StableDiffusion:
pw=2,
c=latent_model_input.shape[1],
)
if bypass_guidance_embedding:
restore_flux_guidance(self.unet)
elif self.is_v3:
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),