mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00

Added training for an experimental decorator embedding. Allow turning off the guidance embedding on flux (for an unreleased model). Various bug fixes and modifications.
@@ -934,8 +934,10 @@ class SDTrainer(BaseSDTrainProcess):
             unconditional_embeddings=unconditional_embeds,
             timestep=timesteps,
             guidance_scale=self.train_config.cfg_scale,
+            guidance_embedding_scale=self.train_config.cfg_scale,
             detach_unconditional=False,
             rescale_cfg=self.train_config.cfg_rescale,
+            bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
             **kwargs
         )
 
@@ -1289,6 +1291,16 @@ class SDTrainer(BaseSDTrainProcess):
             conditional_embeds = conditional_embeds.detach()
             if self.train_config.do_cfg:
                 unconditional_embeds = unconditional_embeds.detach()
 
+        if self.decorator:
+            conditional_embeds.text_embeds = self.decorator(
+                conditional_embeds.text_embeds
+            )
+            if self.train_config.do_cfg:
+                unconditional_embeds.text_embeds = self.decorator(
+                    unconditional_embeds.text_embeds,
+                    is_unconditional=True
+                )
+
         # flush()
         pred_kwargs = {}
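The ordering in this hunk matters: the prompt embeds are detached first, so gradients stop at the text encoder output, and only the decorator's appended tokens remain trainable. A minimal sketch of that effect (the PromptEmbedsStub and all shapes are illustrative assumptions, not part of this commit):

    import torch
    from toolkit.models.decorator import Decorator

    class PromptEmbedsStub:
        # stand-in for toolkit's PromptEmbeds; only the attribute the hunk touches
        def __init__(self, text_embeds: torch.Tensor):
            self.text_embeds = text_embeds

    decorator = Decorator(num_tokens=4, token_size=4096)
    conditional_embeds = PromptEmbedsStub(torch.randn(1, 512, 4096))

    # detach first: gradients stop at the text encoder output...
    conditional_embeds.text_embeds = conditional_embeds.text_embeds.detach()
    # ...then decoration appends tokens that do carry gradients
    conditional_embeds.text_embeds = decorator(conditional_embeds.text_embeds)
    print(conditional_embeds.text_embeds.requires_grad)  # True, via decorator.weight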
@@ -34,6 +34,7 @@ from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \
     lorm_ignore_if_contains, lorm_parameter_threshold, LORM_TARGET_REPLACE_MODULE
 from toolkit.lycoris_special import LycorisSpecialNetwork
+from toolkit.models.decorator import Decorator
 from toolkit.network_mixins import Network
 from toolkit.optimizer import get_optimizer
 from toolkit.paths import CONFIG_ROOT
@@ -56,7 +57,8 @@ import gc
 from tqdm import tqdm
 
 from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
-    GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs
+    GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs, \
+    DecoratorConfig
 from toolkit.logging import create_logger
 from diffusers import FluxTransformer2DModel
@@ -143,6 +145,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
         embedding_raw = self.get_conf('embedding', None)
         if embedding_raw is not None:
             self.embed_config = EmbeddingConfig(**embedding_raw)
 
+        self.decorator_config: DecoratorConfig = None
+        decorator_raw = self.get_conf('decorator', None)
+        if decorator_raw is not None:
+            if not self.model_config.is_flux:
+                raise ValueError("Decorators are only supported for Flux models currently")
+            self.decorator_config = DecoratorConfig(**decorator_raw)
+
         # t2i adapter
         self.adapter_config = None
@@ -157,6 +166,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.network: Union[Network, None] = None
         self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, ControlNetModel, None] = None
         self.embedding: Union[Embedding, None] = None
+        self.decorator: Union[Decorator, None] = None
 
         is_training_adapter = self.adapter_config is not None and self.adapter_config.train
@@ -174,6 +184,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_lora=self.network_config is not None,
             train_adapter=is_training_adapter,
             train_embedding=self.embed_config is not None,
+            train_decorator=self.decorator_config is not None,
             train_refiner=self.train_config.train_refiner,
             unload_text_encoder=self.train_config.unload_text_encoder,
             require_grads=False  # we ensure them later
@@ -187,6 +198,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_lora=self.network_config is not None,
             train_adapter=is_training_adapter,
             train_embedding=self.embed_config is not None,
+            train_decorator=self.decorator_config is not None,
             train_refiner=self.train_config.train_refiner,
             unload_text_encoder=self.train_config.unload_text_encoder,
             require_grads=True  # We check for grads when getting params
@@ -194,7 +206,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         # fine_tuning here means training the actual SD network (Dreambooth, etc), not LoRA, embeddings, etc.
         self.is_fine_tuning = True
-        if self.network_config is not None or is_training_adapter or self.embed_config is not None:
+        if self.network_config is not None or is_training_adapter or self.embed_config is not None or self.decorator_config is not None:
             self.is_fine_tuning = False
 
         self.named_lora = False
@@ -468,6 +480,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # replace extension
             emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
             self.embedding.save(emb_file_path)
 
+        if self.decorator is not None:
+            dec_filename = f'{self.job.name}{step_num}.safetensors'
+            dec_file_path = os.path.join(self.save_root, dec_filename)
+            decorator_state_dict = self.decorator.state_dict()
+            for key, value in decorator_state_dict.items():
+                if isinstance(value, torch.Tensor):
+                    decorator_state_dict[key] = value.clone().to('cpu', dtype=get_torch_dtype(self.save_config.dtype))
+            save_file(
+                decorator_state_dict,
+                dec_file_path,
+                metadata=save_meta,
+            )
+
         if self.adapter is not None and self.adapter_config.train:
             adapter_name = self.job.name
@@ -1506,6 +1531,30 @@ class BaseSDTrainProcess(BaseTrainProcess):
             })
 
             flush()
 
+        if self.decorator_config is not None:
+            self.decorator = Decorator(
+                num_tokens=self.decorator_config.num_tokens,
+                token_size=4096  # t5xxl hidden size for flux
+            )
+            latest_save_path = self.get_latest_save_path()
+            # load last saved weights
+            if latest_save_path is not None:
+                state_dict = load_file(latest_save_path)
+                self.decorator.load_state_dict(state_dict)
+                self.load_training_state_from_metadata(latest_save_path)
+
+            params.append({
+                'params': list(self.decorator.parameters()),
+                'lr': self.train_config.lr
+            })
+
+            # give it to the sd network
+            self.sd.decorator = self.decorator
+            self.decorator.to(self.device_torch, dtype=torch.float32)
+            self.decorator.train()
+
+            flush()
+
         if self.adapter_config is not None:
             self.setup_adapter()
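The save hunk earlier writes the decorator as a bare safetensors state dict, and this setup hunk reads it back on resume. The round trip, as a standalone sketch (the file path is hypothetical):

    import torch
    from safetensors.torch import save_file, load_file
    from toolkit.models.decorator import Decorator

    decorator = Decorator(num_tokens=4, token_size=4096)

    # save the way the trainer does: CPU tensors in a flat .safetensors state dict
    state = {k: v.clone().cpu() for k, v in decorator.state_dict().items()}
    save_file(state, '/tmp/decorator_000000100.safetensors')  # hypothetical path

    # resume the way setup does: load_file then load_state_dict
    resumed = Decorator(num_tokens=4, token_size=4096)
    resumed.load_state_dict(load_file('/tmp/decorator_000000100.safetensors'))
    print(torch.equal(decorator.weight, resumed.weight))  # True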
@@ -227,6 +227,11 @@ class EmbeddingConfig:
         self.trigger_class_name = kwargs.get('trigger_class_name', None)  # used for inverted masked prior
 
 
+class DecoratorConfig:
+    def __init__(self, **kwargs):
+        self.num_tokens: int = kwargs.get('num_tokens', 4)
+
+
 ContentOrStyleType = Literal['balanced', 'style', 'content']
 LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
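DecoratorConfig is deliberately thin: whatever dict sits under the `decorator` key of the job config is splatted into it, and setup turns it into the module. A minimal sketch of that flow (the values are illustrative):

    from toolkit.config_modules import DecoratorConfig
    from toolkit.models.decorator import Decorator

    # a `decorator:` section in the job config arrives as a plain dict
    decorator_config = DecoratorConfig(**{'num_tokens': 8})

    # BaseSDTrainProcess then builds the module with the t5xxl hidden size
    decorator = Decorator(num_tokens=decorator_config.num_tokens, token_size=4096)
    print(sum(p.numel() for p in decorator.parameters()))  # 8 * 4096 = 32768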
@@ -393,6 +398,8 @@ class TrainConfig:
         self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False)
         # 0.1 means 10% of the parameters are active at a time; lower is less vram, higher is more
         self.paramiter_swapping_factor = kwargs.get('paramiter_swapping_factor', 0.1)
+        # bypass the guidance embedding for training. For open flux with guidance embedding
+        self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False)
 
 
 class ModelConfig:
@@ -458,6 +465,7 @@ class ModelConfig:
         # for targeting specific layers
         self.ignore_if_contains: Optional[List[str]] = kwargs.get("ignore_if_contains", None)
         self.only_if_contains: Optional[List[str]] = kwargs.get("only_if_contains", None)
+        self.quantize_kwargs = kwargs.get("quantize_kwargs", {})
 
         if self.ignore_if_contains is not None or self.only_if_contains is not None:
             if not self.is_flux:
@@ -914,4 +922,6 @@ def validate_configs(
     if save_config.save_format != 'diffusers':
         # make it diffusers
        save_config.save_format = 'diffusers'
 
+    if model_config.use_flux_cfg:
+        # bypass the embedding
+        train_config.bypass_guidance_embedding = True
@@ -205,7 +205,8 @@ class CustomAdapter(torch.nn.Module):
         elif self.adapter_type == 'single_value':
             self.single_value_adapter = SingleValueAdapter(self, self.sd_ref(), num_values=self.config.num_tokens)
         elif self.adapter_type == 'redux':
-            self.redux_adapter = ReduxImageEncoder(1152, 4096, self.device, torch_dtype)
+            vision_hidden_size = self.vision_encoder.config.hidden_size
+            self.redux_adapter = ReduxImageEncoder(vision_hidden_size, 4096, self.device, torch_dtype)
         else:
             raise ValueError(f"unknown adapter type: {self.adapter_type}")
33  toolkit/models/decorator.py  Normal file
@@ -0,0 +1,33 @@
+import torch
+
+
+class Decorator(torch.nn.Module):
+    def __init__(
+        self,
+        num_tokens: int = 4,
+        token_size: int = 4096,
+    ) -> None:
+        super().__init__()
+
+        self.weight: torch.nn.Parameter = torch.nn.Parameter(
+            torch.randn(num_tokens, token_size)
+        )
+        # ensure it is float32
+        self.weight.data = self.weight.data.float()
+
+    def forward(self, text_embeds: torch.Tensor, is_unconditional=False) -> torch.Tensor:
+        # make sure the param stays float32
+        if self.weight.dtype != text_embeds.dtype:
+            self.weight.data = self.weight.data.float()
+        # expand batch to match text_embeds
+        batch_size = text_embeds.shape[0]
+        decorator_embeds = self.weight.unsqueeze(0).expand(batch_size, -1, -1)
+        if is_unconditional:
+            # zero out the decorator embeds for the unconditional branch
+            decorator_embeds = torch.zeros_like(decorator_embeds)
+
+        if decorator_embeds.dtype != text_embeds.dtype:
+            decorator_embeds = decorator_embeds.to(text_embeds.dtype)
+        text_embeds = torch.cat((text_embeds, decorator_embeds), dim=-2)
+
+        return text_embeds
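To make the tensor contract concrete: the decorator appends its learned tokens on the sequence axis, and the unconditional branch gets the same slots zeroed so CFG can contrast against them. A quick sketch (shapes assumed, not from the diff):

    import torch
    from toolkit.models.decorator import Decorator

    decorator = Decorator(num_tokens=4, token_size=4096)

    # stand-in for a batch of T5-XXL prompt embeddings: (batch, seq_len, hidden)
    text_embeds = torch.randn(2, 512, 4096)

    out = decorator(text_embeds)
    print(out.shape)  # torch.Size([2, 516, 4096]): 4 learned tokens appended

    # unconditional branch: the appended slots are zeroed, not removed
    uncond = decorator(text_embeds, is_unconditional=True)
    print(torch.allclose(uncond[:, 512:], torch.zeros(2, 4, 4096)))  # True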
35  toolkit/models/flux.py  Normal file
@@ -0,0 +1,35 @@
+
+# forward that bypasses the guidance embedding so it can be avoided during training.
+from functools import partial
+
+
+def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
+    timesteps_proj = self.time_proj(timestep)
+    timesteps_emb = self.timestep_embedder(
+        timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+    pooled_projections = self.text_embedder(pooled_projection)
+    conditioning = timesteps_emb + pooled_projections
+    return conditioning
+
+
+# bypass the forward function
+def bypass_flux_guidance(transformer):
+    if hasattr(transformer.time_text_embed, '_bfg_orig_forward'):
+        return
+    # don't bypass if it doesn't have the guidance embedding
+    if not hasattr(transformer.time_text_embed, 'guidance_embedder'):
+        return
+    transformer.time_text_embed._bfg_orig_forward = transformer.time_text_embed.forward
+    transformer.time_text_embed.forward = partial(
+        guidance_embed_bypass_forward, transformer.time_text_embed
+    )
+
+
+# restore the forward function
+def restore_flux_guidance(transformer):
+    if not hasattr(transformer.time_text_embed, '_bfg_orig_forward'):
+        return
+    transformer.time_text_embed.forward = transformer.time_text_embed._bfg_orig_forward
+    del transformer.time_text_embed._bfg_orig_forward
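The pattern here is a guarded monkey-patch: the original forward is stashed once under `_bfg_orig_forward`, repeated bypass calls are no-ops, and restore undoes it exactly. The helpers only touch two attributes, so their behavior can be checked with tiny stand-ins (hypothetical mocks, not the real FluxTransformer2DModel):

    from types import SimpleNamespace
    from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance

    # the helpers only look at time_text_embed.forward and the presence
    # of a guidance_embedder attribute
    time_text_embed = SimpleNamespace(
        guidance_embedder=object(),  # presence is all that's checked
        forward=lambda timestep, guidance, pooled_projection: "with guidance",
    )
    transformer = SimpleNamespace(time_text_embed=time_text_embed)

    bypass_flux_guidance(transformer)
    assert hasattr(transformer.time_text_embed, '_bfg_orig_forward')
    bypass_flux_guidance(transformer)  # second call is a no-op; original saved once

    restore_flux_guidance(transformer)
    assert not hasattr(transformer.time_text_embed, '_bfg_orig_forward')
    assert transformer.time_text_embed.forward(None, None, None) == "with guidance"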
@@ -264,6 +264,10 @@ class Adafactor(torch.optim.Optimizer):
                 if grad.is_sparse:
                     raise RuntimeError(
                         "Adafactor does not support sparse gradients.")
 
+                # if p has attr _scale then it is quantized. We need to divide the grad by the scale
+                # if hasattr(p, "_scale"):
+                #     grad = grad / p._scale
+
                 state = self.state[p]
                 grad_shape = grad.shape
@@ -101,7 +101,7 @@ class Automagic(torch.optim.Optimizer):
         if 'avg_lr' in param_state:
             lr = param_state["avg_lr"]
         else:
-            lr = param_state["lr"]
+            lr = 0.0
         return lr
 
     def _get_group_lr(self, group):
@@ -332,4 +332,4 @@ class Automagic(torch.optim.Optimizer):
             state['lr_mask'] = Auto8bitTensor(sd_mask)
             del state_dict['state'][idx]['lr_mask']
             idx += 1
-        super().load_state_dict(state_dict, strict)
+        super().load_state_dict(state_dict)
@@ -14,6 +14,7 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import
 from diffusers.utils import is_torch_xla_available
 from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
 from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
+from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
 
 
 if is_torch_xla_available():
@@ -1235,6 +1236,8 @@ class FluxWithCFGPipeline(FluxPipeline):
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
     ):
+        # bypass the guidance embedding if there is one
+        bypass_flux_guidance(self.transformer)
 
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
@@ -1410,6 +1413,7 @@ class FluxWithCFGPipeline(FluxPipeline):
 
         # Offload all models
         self.maybe_free_model_hooks()
+        restore_flux_guidance(self.transformer)
 
         if not return_dict:
             return (image,)
@@ -39,6 +39,7 @@ def get_train_sd_device_state_preset(
     train_lora: bool = False,
     train_adapter: bool = False,
     train_embedding: bool = False,
+    train_decorator: bool = False,
     train_refiner: bool = False,
     unload_text_encoder: bool = False,
     require_grads: bool = True,
@@ -89,6 +90,14 @@ def get_train_sd_device_state_preset(
         preset['unet']['requires_grad'] = False
         preset['unet']['device'] = device
         preset['text_encoder']['device'] = device
 
+    if train_decorator:
+        preset['text_encoder']['training'] = False
+        preset['text_encoder']['requires_grad'] = False
+        preset['text_encoder']['device'] = device
+        preset['unet']['training'] = True
+        preset['unet']['requires_grad'] = False
+        preset['unet']['device'] = device
+
     if unload_text_encoder:
         preset['text_encoder']['training'] = False
@@ -31,6 +31,7 @@ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encod
 from toolkit import train_tools
 from toolkit.config_modules import ModelConfig, GenerateImageConfig
 from toolkit.metadata import get_meta_for_safetensors
+from toolkit.models.decorator import Decorator
 from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
 from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
 from toolkit.reference_adapter import ReferenceAdapter
@@ -59,6 +60,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjecti
 
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from huggingface_hub import hf_hub_download
+from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
 
 from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
 from typing import TYPE_CHECKING
@@ -165,6 +167,7 @@ class StableDiffusion:
         # to hold network if there is one
         self.network = None
         self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
+        self.decorator: Union[Decorator, None] = None
         self.is_xl = model_config.is_xl
         self.is_v2 = model_config.is_v2
         self.is_ssd = model_config.is_ssd
@@ -668,7 +671,7 @@ class StableDiffusion:
             patch_dequantization_on_save(transformer)
             quantization_type = qfloat8
             print("Quantizing transformer")
-            quantize(transformer, weights=quantization_type)
+            quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
             freeze(transformer)
             transformer.to(self.device_torch)
         else:
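With this change, whatever dict is placed under `quantize_kwargs` in the model config is forwarded verbatim into optimum.quanto's quantize call. A sketch of the kind of thing that enables; the `exclude` keyword is an assumption about optimum.quanto's API, not something this commit documents:

    import torch
    from optimum.quanto import quantize, qfloat8, freeze

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))

    # mirrors quantize(transformer, weights=..., **self.model_config.quantize_kwargs)
    quantize_kwargs = {"exclude": ["1"]}  # assumed: skip the second layer by module name
    quantize(model, weights=qfloat8, **quantize_kwargs)
    freeze(model)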
@@ -1209,6 +1212,11 @@ class StableDiffusion:
             conditional_embeds,
             unconditional_embeds,
         )
 
+        if self.decorator is not None:
+            # apply the decorator to the embeddings
+            conditional_embeds.text_embeds = self.decorator(conditional_embeds.text_embeds)
+            unconditional_embeds.text_embeds = self.decorator(unconditional_embeds.text_embeds, is_unconditional=True)
+
         if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
                 and gen_config.adapter_image_path is not None:
@@ -1566,6 +1574,7 @@ class StableDiffusion:
         rescale_cfg=None,
         return_conditional_pred=False,
         guidance_embedding_scale=1.0,
+        bypass_guidance_embedding=False,
         **kwargs,
     ):
         conditional_pred = None
@@ -1842,13 +1851,16 @@ class StableDiffusion:
 
         # # handle guidance
         if self.unet.config.guidance_embeds:
-            if isinstance(guidance_scale, list):
-                guidance = torch.tensor(guidance_scale, device=self.device_torch)
+            if isinstance(guidance_embedding_scale, list):
+                guidance = torch.tensor(guidance_embedding_scale, device=self.device_torch)
             else:
-                guidance = torch.tensor([guidance_scale], device=self.device_torch)
+                guidance = torch.tensor([guidance_embedding_scale], device=self.device_torch)
             guidance = guidance.expand(latents.shape[0])
         else:
             guidance = None
 
+        if bypass_guidance_embedding:
+            bypass_flux_guidance(self.unet)
+
         cast_dtype = self.unet.dtype
         # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
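The rename decouples the value fed to flux's guidance embedder from the CFG scale applied to the predictions, so the two can now differ. How the guidance tensor itself is built, as a standalone sketch with assumed values:

    import torch

    # hypothetical values: one embedded-guidance scale, a batch of two latents
    guidance_embedding_scale = 4.0
    latents = torch.zeros(2, 16, 32, 32)  # stand-in latent batch

    if isinstance(guidance_embedding_scale, list):
        guidance = torch.tensor(guidance_embedding_scale)
    else:
        guidance = torch.tensor([guidance_embedding_scale])
    guidance = guidance.expand(latents.shape[0])
    print(guidance)  # tensor([4., 4.]): one scale per batch item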
@@ -1879,6 +1891,9 @@ class StableDiffusion:
                 pw=2,
                 c=latent_model_input.shape[1],
             )
 
+            if bypass_guidance_embedding:
+                restore_flux_guidance(self.unet)
+
         elif self.is_v3:
             noise_pred = self.unet(
                 hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),