mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
Added training for an experimental decoratgor embedding. Allow for turning off guidance embedding on flux (for unreleased model). Various bug fixes and modifications
This commit is contained in:
@@ -31,6 +31,7 @@ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encod
|
||||
from toolkit import train_tools
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.models.decorator import Decorator
|
||||
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
|
||||
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
|
||||
from toolkit.reference_adapter import ReferenceAdapter
|
||||
@@ -59,6 +60,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjecti
|
||||
|
||||
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
|
||||
from huggingface_hub import hf_hub_download
|
||||
from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
|
||||
|
||||
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -165,6 +167,7 @@ class StableDiffusion:
|
||||
# to hold network if there is one
|
||||
self.network = None
|
||||
self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
|
||||
self.decorator: Union[Decorator, None] = None
|
||||
self.is_xl = model_config.is_xl
|
||||
self.is_v2 = model_config.is_v2
|
||||
self.is_ssd = model_config.is_ssd
|
||||
@@ -668,7 +671,7 @@ class StableDiffusion:
|
||||
patch_dequantization_on_save(transformer)
|
||||
quantization_type = qfloat8
|
||||
print("Quantizing transformer")
|
||||
quantize(transformer, weights=quantization_type)
|
||||
quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
|
||||
freeze(transformer)
|
||||
transformer.to(self.device_torch)
|
||||
else:
|
||||
@@ -1209,6 +1212,11 @@ class StableDiffusion:
|
||||
conditional_embeds,
|
||||
unconditional_embeds,
|
||||
)
|
||||
|
||||
if self.decorator is not None:
|
||||
# apply the decorator to the embeddings
|
||||
conditional_embeds.text_embeds = self.decorator(conditional_embeds.text_embeds)
|
||||
unconditional_embeds.text_embeds = self.decorator(unconditional_embeds.text_embeds, is_unconditional=True)
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
|
||||
and gen_config.adapter_image_path is not None:
|
||||
@@ -1566,6 +1574,7 @@ class StableDiffusion:
|
||||
rescale_cfg=None,
|
||||
return_conditional_pred=False,
|
||||
guidance_embedding_scale=1.0,
|
||||
bypass_guidance_embedding=False,
|
||||
**kwargs,
|
||||
):
|
||||
conditional_pred = None
|
||||
@@ -1842,13 +1851,16 @@ class StableDiffusion:
|
||||
|
||||
# # handle guidance
|
||||
if self.unet.config.guidance_embeds:
|
||||
if isinstance(guidance_scale, list):
|
||||
guidance = torch.tensor(guidance_scale, device=self.device_torch)
|
||||
if isinstance(guidance_embedding_scale, list):
|
||||
guidance = torch.tensor(guidance_embedding_scale, device=self.device_torch)
|
||||
else:
|
||||
guidance = torch.tensor([guidance_scale], device=self.device_torch)
|
||||
guidance = torch.tensor([guidance_embedding_scale], device=self.device_torch)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
if bypass_guidance_embedding:
|
||||
bypass_flux_guidance(self.unet)
|
||||
|
||||
cast_dtype = self.unet.dtype
|
||||
# with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
|
||||
@@ -1879,6 +1891,9 @@ class StableDiffusion:
|
||||
pw=2,
|
||||
c=latent_model_input.shape[1],
|
||||
)
|
||||
|
||||
if bypass_guidance_embedding:
|
||||
restore_flux_guidance(self.unet)
|
||||
elif self.is_v3:
|
||||
noise_pred = self.unet(
|
||||
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||
|
||||
Reference in New Issue
Block a user