Added training for an experimental decorator embedding. Allow turning off the guidance embedding on Flux (for an unreleased model). Various bug fixes and modifications.

Jaret Burkett
2024-12-15 08:59:27 -07:00
parent 92ce93140e
commit 8ef07a9c36
11 changed files with 182 additions and 10 deletions


@@ -31,6 +31,7 @@ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encod
 from toolkit import train_tools
 from toolkit.config_modules import ModelConfig, GenerateImageConfig
 from toolkit.metadata import get_meta_for_safetensors
+from toolkit.models.decorator import Decorator
 from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
 from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
 from toolkit.reference_adapter import ReferenceAdapter
@@ -59,6 +60,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjecti
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from huggingface_hub import hf_hub_download
+from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
 from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
 from typing import TYPE_CHECKING
@@ -165,6 +167,7 @@ class StableDiffusion:
         # to hold network if there is one
         self.network = None
         self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
+        self.decorator: Union[Decorator, None] = None
         self.is_xl = model_config.is_xl
         self.is_v2 = model_config.is_v2
         self.is_ssd = model_config.is_ssd
@@ -668,7 +671,7 @@ class StableDiffusion:
                 patch_dequantization_on_save(transformer)
                 quantization_type = qfloat8
                 print("Quantizing transformer")
-                quantize(transformer, weights=quantization_type)
+                quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
                 freeze(transformer)
                 transformer.to(self.device_torch)
             else:
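
Since quantize_kwargs is forwarded verbatim to quanto's quantize(), the model config can now carry any extra argument that call accepts. A hedged sketch of what that enables; the exclude patterns are purely illustrative, not defaults from this repo:

from optimum.quanto import freeze, qfloat8, quantize

# `transformer` is the loaded Flux transformer from the surrounding code.
# The exclude patterns below are illustrative only: they keep the matched
# modules in their original precision instead of quantizing them.
quantize_kwargs = {"exclude": ["*.norm*"]}  # hypothetical config value
quantize(transformer, weights=qfloat8, **quantize_kwargs)
freeze(transformer)
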
@@ -1209,6 +1212,11 @@ class StableDiffusion:
                     conditional_embeds,
                     unconditional_embeds,
                 )
+            if self.decorator is not None:
+                # apply the decorator to the embeddings
+                conditional_embeds.text_embeds = self.decorator(conditional_embeds.text_embeds)
+                unconditional_embeds.text_embeds = self.decorator(unconditional_embeds.text_embeds, is_unconditional=True)
             if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
                     and gen_config.adapter_image_path is not None:
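
The Decorator module itself (toolkit/models/decorator.py) is not part of this diff. A minimal sketch of the idea only, assuming it learns a small bank of trainable tokens concatenated onto the prompt embeddings, with the unconditional branch kept neutral; all names and shapes here are guesses, not the shipped implementation:

import torch
import torch.nn as nn

class Decorator(nn.Module):
    # Hypothetical sketch: a small bank of trainable tokens appended
    # to the prompt embeddings (batch, seq, dim).
    def __init__(self, dim: int, num_tokens: int = 4):
        super().__init__()
        self.tokens = nn.Parameter(torch.zeros(1, num_tokens, dim))

    def forward(self, text_embeds: torch.Tensor, is_unconditional: bool = False) -> torch.Tensor:
        tokens = self.tokens.expand(text_embeds.shape[0], -1, -1)
        if is_unconditional:
            # keep the unconditional branch neutral
            tokens = torch.zeros_like(tokens)
        return torch.cat([text_embeds, tokens.to(text_embeds.dtype)], dim=1)
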
@@ -1566,6 +1574,7 @@ class StableDiffusion:
             rescale_cfg=None,
             return_conditional_pred=False,
             guidance_embedding_scale=1.0,
+            bypass_guidance_embedding=False,
             **kwargs,
     ):
         conditional_pred = None
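
With the new flag, callers can opt out of the guidance embedding per prediction. A hypothetical call site; apart from guidance_embedding_scale and bypass_guidance_embedding (visible in the signature above), every name is assumed:

# Hypothetical usage: `sd` is a StableDiffusion instance wrapping a Flux
# model; latents, conditional_embeds, and timestep come from the caller.
noise_pred = sd.predict_noise(
    latents,
    text_embeddings=conditional_embeds,
    timestep=timestep,
    guidance_embedding_scale=1.0,
    bypass_guidance_embedding=True,  # skip Flux's guidance pathway entirely
)
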
@@ -1842,13 +1851,16 @@ class StableDiffusion:
             # # handle guidance
             if self.unet.config.guidance_embeds:
-                if isinstance(guidance_scale, list):
-                    guidance = torch.tensor(guidance_scale, device=self.device_torch)
+                if isinstance(guidance_embedding_scale, list):
+                    guidance = torch.tensor(guidance_embedding_scale, device=self.device_torch)
                 else:
-                    guidance = torch.tensor([guidance_scale], device=self.device_torch)
+                    guidance = torch.tensor([guidance_embedding_scale], device=self.device_torch)
                 guidance = guidance.expand(latents.shape[0])
             else:
                 guidance = None
+            if bypass_guidance_embedding:
+                bypass_flux_guidance(self.unet)
             cast_dtype = self.unet.dtype
             # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
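
bypass_flux_guidance and restore_flux_guidance live in toolkit/models/flux.py and are not shown in this diff. One way such a pair could work, assuming diffusers' Flux layout where the guidance MLP sits at transformer.time_text_embed.guidance_embedder; a sketch, not the repo's actual code:

import torch
import torch.nn as nn

class _ZeroGuidance(nn.Module):
    # Stand-in that contributes nothing to the timestep conditioning.
    def __init__(self, out_features: int):
        super().__init__()
        self.out_features = out_features

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        return torch.zeros(sample.shape[0], self.out_features,
                           device=sample.device, dtype=sample.dtype)

def bypass_flux_guidance(transformer) -> None:
    # stash the real guidance MLP and swap in the zero stand-in
    embedder = transformer.time_text_embed.guidance_embedder
    transformer.time_text_embed._stashed_guidance_embedder = embedder
    transformer.time_text_embed.guidance_embedder = _ZeroGuidance(
        embedder.linear_2.out_features  # final projection width (assumption)
    )

def restore_flux_guidance(transformer) -> None:
    stashed = getattr(transformer.time_text_embed,
                      "_stashed_guidance_embedder", None)
    if stashed is not None:
        transformer.time_text_embed.guidance_embedder = stashed
        del transformer.time_text_embed._stashed_guidance_embedder
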
@@ -1879,6 +1891,9 @@ class StableDiffusion:
                     pw=2,
                     c=latent_model_input.shape[1],
                 )
+            if bypass_guidance_embedding:
+                restore_flux_guidance(self.unet)
         elif self.is_v3:
             noise_pred = self.unet(
                 hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),