fixed huge flux training bug. Added ability to use an assistant lora

Jaret Burkett
2024-08-14 10:14:13 -06:00
parent e07bf11727
commit 7fed4ea761
4 changed files with 124 additions and 49 deletions
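For context, a rough sketch of how the new option is meant to be used, assuming the usual `ModelConfig`/`StableDiffusion` wiring from `toolkit` (the model id, path, and exact constructor arguments below are illustrative, not taken from this commit):

```python
# Sketch only: assumes ModelConfig accepts the new assistant_lora_path field
# and that StableDiffusion.load_model() picks it up, as the diff below suggests.
from toolkit.config_modules import ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion

model_config = ModelConfig(
    name_or_path="black-forest-labs/FLUX.1-schnell",            # hypothetical base model
    is_flux=True,                                                # assistant adapters are Flux-only for now
    assistant_lora_path="/path/to/assistant_lora.safetensors",   # new in this commit
)

sd = StableDiffusion(device="cuda", model_config=model_config, dtype="bf16")
sd.load_model()  # load_assistant_lora_from_path() runs here when the path is set
```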

toolkit/assistant_lora.py  (new file, 55 additions)

@@ -0,0 +1,55 @@
from typing import TYPE_CHECKING

from toolkit.config_modules import NetworkConfig
from toolkit.lora_special import LoRASpecialNetwork
from safetensors.torch import load_file

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion


def load_assistant_lora_from_path(adapter_path, sd: 'StableDiffusion') -> LoRASpecialNetwork:
    if not sd.is_flux:
        raise ValueError("Only Flux models can load assistant adapters currently.")
    pipe = sd.pipeline

    print(f"Loading assistant adapter from {adapter_path}")
    adapter_name = adapter_path.split("/")[-1].split(".")[0]
    lora_state_dict = load_file(adapter_path)
    linear_dim = int(lora_state_dict['transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight'].shape[0])
    # linear_alpha = int(lora_state_dict['lora_transformer_single_transformer_blocks_0_attn_to_k.alpha'].item())
    linear_alpha = linear_dim
    transformer_only = 'transformer.proj_out.alpha' not in lora_state_dict

    # get dim and scale
    network_config = NetworkConfig(
        linear=linear_dim,
        linear_alpha=linear_alpha,
        transformer_only=transformer_only,
    )

    network = LoRASpecialNetwork(
        text_encoder=pipe.text_encoder,
        unet=pipe.transformer,
        lora_dim=network_config.linear,
        multiplier=1.0,
        alpha=network_config.linear_alpha,
        train_unet=True,
        train_text_encoder=False,
        is_flux=True,
        network_config=network_config,
        network_type=network_config.type,
        transformer_only=network_config.transformer_only,
        is_assistant_adapter=True
    )
    network.apply_to(
        pipe.text_encoder,
        pipe.transformer,
        apply_text_encoder=False,
        apply_unet=True
    )
    network.force_to(sd.device_torch, dtype=sd.torch_dtype)
    network.eval()
    network._update_torch_multiplier()
    network.load_weights(lora_state_dict)
    network.is_active = True

    return network
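One detail worth calling out: the loader infers rank and alpha from the checkpoint itself rather than from config. For a PEFT-style key like `...lora_A.weight` with shape `[rank, in_features]`, `shape[0]` is the rank, and alpha is assumed equal to the rank. A tiny illustration of that convention (made-up layer width):

```python
import torch

# LoRA pair for a hypothetical 3072-wide linear layer at rank 16:
#   lora_A ("down" projection): [rank, in_features]
#   lora_B ("up" projection):   [out_features, rank]
lora_A = torch.zeros(16, 3072)
lora_B = torch.zeros(3072, 16)

rank = lora_A.shape[0]          # 16 -- what load_assistant_lora_from_path uses as linear_dim
assert rank == lora_B.shape[1]  # both halves agree on the rank
```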

toolkit/lora_special.py

@@ -175,6 +175,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            full_train_in_out: bool = False,
            transformer_only: bool = False,
            peft_format: bool = False,
+            is_assistant_adapter: bool = False,
            **kwargs
    ) -> None:
        """
@@ -223,6 +224,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        self.is_auraflow = is_auraflow
        self.is_flux = is_flux
        self.network_type = network_type
+        self.is_assistant_adapter = is_assistant_adapter
        if self.network_type.lower() == "dora":
            self.module_class = DoRAModule
            module_class = DoRAModule

toolkit/network_mixins.py

@@ -263,7 +263,7 @@ class ToolkitModuleMixin:
        if isinstance(x, QTensor):
            x = x.dequantize()
        # always cast to float32
-        lora_input = x.float()
+        lora_input = x.to(self.lora_down.weight.dtype)
        lora_output = self._call_forward(lora_input)
        multiplier = self.network_ref().torch_multiplier
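This one-line change is presumably the training bug from the commit message: the LoRA branch now runs in the dtype of its own down-projection weights instead of always upcasting the input to float32, which avoids mixing float32 activations into a bf16 (or quantized) Flux forward pass. A minimal sketch of the fixed pattern, with hypothetical shapes and dtypes:

```python
import torch

# Hypothetical bf16 LoRA pair; the point is only that input and weights match.
lora_down = torch.nn.Linear(3072, 16, bias=False).to(torch.bfloat16)
lora_up = torch.nn.Linear(16, 3072, bias=False).to(torch.bfloat16)

x = torch.randn(2, 3072, dtype=torch.bfloat16)

lora_input = x.to(lora_down.weight.dtype)     # new behaviour: follow the LoRA weight dtype
lora_output = lora_up(lora_down(lora_input))  # no float32/bf16 mixing in the branch
```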

toolkit/stable_diffusion_model.py

@@ -11,7 +11,8 @@ from collections import OrderedDict
import copy
import yaml
from PIL import Image
-from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
+from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, \
+    ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from safetensors.torch import save_file, load_file
from torch import autocast
@@ -20,6 +21,7 @@ from torch.utils.checkpoint import checkpoint
from tqdm import tqdm
from torchvision.transforms import Resize, transforms
+from toolkit.assistant_lora import load_assistant_lora_from_path
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.ip_adapter import IPAdapter
@@ -57,6 +59,10 @@ from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from toolkit.util.inverse_cfg import inverse_classifier_guidance
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from toolkit.lora_special import LoRASpecialNetwork

# tell it to shut up
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -84,7 +90,6 @@ DO_NOT_TRAIN_WEIGHTS = [
DeviceStatePreset = Literal['cache_latents', 'generate']

class BlankNetwork:
    def __init__(self):
@@ -127,10 +132,12 @@ class StableDiffusion:
        self.torch_dtype = get_torch_dtype(dtype)
        self.device_torch = torch.device(self.device)
-        self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(model_config.vae_device)
+        self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(
+            model_config.vae_device)
        self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
-        self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(model_config.te_device)
+        self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(
+            model_config.te_device)
        self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
        self.model_config = model_config
@@ -146,6 +153,7 @@ class StableDiffusion:
        self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
        self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None
+        self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None
        # sdxl stuff
        self.logit_scale = None
@@ -270,7 +278,7 @@ class StableDiffusion:
        # see if path exists
        if not os.path.exists(model_path) or os.path.isdir(model_path):
            try:
                # try to load with default diffusers
                pipe = pipln.from_pretrained(
                    model_path,
                    dtype=dtype,
@@ -462,6 +470,8 @@ class StableDiffusion:
            print("Loading transformer")
            subfolder = 'transformer'
            transformer_path = model_path
+            local_files_only = False
+            # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set
            if os.path.exists(transformer_path):
                subfolder = None
                transformer_path = os.path.join(transformer_path, 'transformer')
@@ -518,7 +528,8 @@ class StableDiffusion:
            print("Loading t5")
            tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
-            text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
+            text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
+                                                            torch_dtype=dtype)
            text_encoder_2.to(self.device_torch, dtype=dtype)
            flush()
@@ -655,21 +666,17 @@ class StableDiffusion:
                # unfortunately, not an easier way with peft
                pipe.unload_lora_weights()
-            if self.model_config.assistant_lora_path is not None:
-                if self.model_config.lora_path is not None:
-                    raise ValueError("Cannot have both lora and assistant lora")
-                print("Loading assistant lora")
-                pipe.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
-                pipe.fuse_lora(lora_scale=1.0)
-                # unfortunately, not an easier way with peft
-                pipe.unload_lora_weights()

            self.tokenizer = tokenizer
            self.text_encoder = text_encoder
            self.pipeline = pipe
            self.load_refiner()
            self.is_loaded = True

+        if self.model_config.assistant_lora_path is not None:
+            print("Loading assistant lora")
+            self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
+                self.model_config.assistant_lora_path, self)
        if self.is_pixart and self.vae_scale_factor == 16:
            # TODO make our own pipeline?
            # we generate an image 2x larger, so we need to copy the sizes from larger ones down
@@ -741,9 +748,7 @@ class StableDiffusion:
        if self.model_config.assistant_lora_path is not None:
            print("Unloading asistant lora")
            # unfortunately, not an easier way with peft
-            self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
-            self.pipeline.fuse_lora(lora_scale=-1.0)
-            self.pipeline.unload_lora_weights()
+            self.assistant_lora.is_active = False

        if self.network is not None:
            self.network.eval()
@@ -1027,7 +1032,6 @@ class StableDiffusion:
            if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
                    and gen_config.adapter_image_path is not None:
                # apply the image projection
                conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
                unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,
@@ -1035,7 +1039,8 @@ class StableDiffusion:
                conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
                unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
-            if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
+            if self.adapter is not None and isinstance(self.adapter,
+                                                        CustomAdapter) and validation_image is not None:
                conditional_embeds = self.adapter.condition_encoded_embeds(
                    tensors_0_1=validation_image,
                    prompt_embeds=conditional_embeds,
@@ -1052,13 +1057,14 @@ class StableDiffusion:
                    is_generating_samples=True,
                )

-            if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len(gen_config.extra_values) > 0:
-                extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, dtype=self.torch_dtype)
+            if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len(
+                    gen_config.extra_values) > 0:
+                extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch,
+                                            dtype=self.torch_dtype)
                # apply extra values to the embeddings
                self.adapter.add_extra_values(extra_values, is_unconditional=False)
                self.adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True)
                pass  # todo remove, for debugging

            if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
                # if we have a refiner loaded, set the denoising end at the refiner start
@@ -1148,9 +1154,12 @@ class StableDiffusion:
                img = pipeline(
                    prompt=None,
                    prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                    prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
-                    negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                    negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
+                    prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch,
+                                                                               dtype=self.unet.dtype),
+                    negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch,
+                                                                               dtype=self.unet.dtype),
+                    negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch,
+                                                                                          dtype=self.unet.dtype),
                    negative_prompt=None,
                    # negative_prompt=gen_config.negative_prompt,
                    height=gen_config.height,
@@ -1166,9 +1175,12 @@ class StableDiffusion:
                img = pipeline(
                    prompt=None,
                    prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                    prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
-                    negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                    negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
+                    prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch,
+                                                                               dtype=self.unet.dtype),
+                    negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch,
+                                                                               dtype=self.unet.dtype),
+                    negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch,
+                                                                                          dtype=self.unet.dtype),
                    negative_prompt=None,
                    # negative_prompt=gen_config.negative_prompt,
                    height=gen_config.height,
@@ -1247,9 +1259,7 @@ class StableDiffusion:
        if self.model_config.assistant_lora_path is not None:
            print("Loading asistant lora")
            # unfortunately, not an easier way with peft
-            self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
-            self.pipeline.fuse_lora(lora_scale=1.0)
-            self.pipeline.unload_lora_weights()
+            self.assistant_lora.is_active = True

    def get_latent_noise(
            self,
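Taken together, the assistant-LoRA hunks above replace the load → fuse → unload round trip through peft with a flag flip on the persistent network built by `load_assistant_lora_from_path`. A condensed sketch of what the two changed call sites now amount to (the helper names are hypothetical; in the diff the toggles live inside the existing mode-switching methods):

```python
# Sketch only: sd is a loaded StableDiffusion with an assistant adapter attached.
def enter_training_mode(sd):
    if sd.assistant_lora is not None:
        sd.assistant_lora.is_active = False  # adapter contributes nothing to training forwards

def enter_sampling_mode(sd):
    if sd.assistant_lora is not None:
        sd.assistant_lora.is_active = True   # adapter output is added back in for generation
```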
@@ -1332,7 +1342,8 @@ class StableDiffusion:
        noisy_latents_chunks = []
        for idx in range(original_samples.shape[0]):
-            noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx], timesteps_chunks[idx])
+            noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx],
+                                                           timesteps_chunks[idx])
            noisy_latents_chunks.append(noisy_latents)
        noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
@@ -1392,7 +1403,6 @@ class StableDiffusion:
        else:
            timestep = timestep.repeat(latents.shape[0], 0)

        # handle t2i adapters
        if 'down_intrablock_additional_residuals' in kwargs:
            # go through each item and concat if doing cfg and it doesnt have the same shape
@@ -1561,7 +1571,6 @@ class StableDiffusion:
            height = h * VAE_SCALE_FACTOR
            width = w * VAE_SCALE_FACTOR

            if self.pipeline.transformer.config.sample_size == 256:
                aspect_ratio_bin = ASPECT_RATIO_2048_BIN
            elif self.pipeline.transformer.config.sample_size == 128:
@@ -1573,10 +1582,12 @@ class StableDiffusion:
            else:
                raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}")
            orig_height, orig_width = height, width
-            height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+            height, width = self.pipeline.image_processor.classify_height_width_bin(height, width,
+                                                                                    ratios=aspect_ratio_bin)

        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
-        if self.unet.config.sample_size == 128 or (self.vae_scale_factor == 16 and self.unet.config.sample_size == 64):
+        if self.unet.config.sample_size == 128 or (
+                self.vae_scale_factor == 16 and self.unet.config.sample_size == 64):
            resolution = torch.tensor([height, width]).repeat(batch_size, 1)
            aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1)
            resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch)
@@ -1641,7 +1652,8 @@ class StableDiffusion:
                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
                # todo make sure this doesnt change
                timestep=timestep / 1000,  # timestep is 1000 scale
-                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype),  # [1, 512, 4096]
+                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype),
+                # [1, 512, 4096]
                pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype),  # [1, 768]
                txt_ids=txt_ids,  # [1, 512, 3]
                img_ids=img_ids,  # [1, 4096, 3]
@@ -1705,7 +1717,7 @@ class StableDiffusion:
                with torch.no_grad():
                    # do cfg at the target rescale so we can match it
                    target_pred_mean_std = noise_pred_uncond + rescale_cfg * (
                            noise_pred_text - noise_pred_uncond
                    )
                    target_mean = target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach()
                    target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach()
@@ -1910,7 +1922,7 @@ class StableDiffusion:
            self.text_encoder,
            prompt,
            truncate=not long_prompts,
            max_length=77,  # todo set this higher when not transfer learning
            dropout_prob=dropout_prob
        )
        return PromptEmbeds(
@@ -1957,16 +1969,19 @@ class StableDiffusion:
        for i in range(len(image_list)):
            image = image_list[i]
            if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0:
-                image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)
+                image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR,
+                                        image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)

        images = torch.stack(image_list)
        if isinstance(self.vae, AutoencoderTiny):
            latents = self.vae.encode(images, return_dict=False)[0]
        else:
            latents = self.vae.encode(images).latent_dist.sample()
-            # latents = self.vae.encode(images, return_dict=False)[0]

        shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
-        latents = latents * (self.vae.config['scaling_factor'] - shift)
+        # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303
+        # z = self.scale_factor * (z - self.shift_factor)
+        latents = self.vae.config['scaling_factor'] * (latents - shift)

        latents = latents.to(device, dtype=dtype)

        return latents
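The latent scaling change is the other substantive fix: the old expression `latents * (scaling_factor - shift)` only matches the referenced Flux autoencoder when `shift_factor` is zero, whereas the new `scaling_factor * (latents - shift)` follows `z = scale * (z - shift)` from the linked source. A quick numeric comparison with illustrative values (roughly the FLUX.1 VAE constants):

```python
# Illustrative only; the real values come from self.vae.config at runtime.
scaling_factor, shift_factor = 0.3611, 0.1159
z = 1.0  # a raw VAE sample

old = z * (scaling_factor - shift_factor)  # 0.2452 -- wrong whenever shift_factor != 0
new = scaling_factor * (z - shift_factor)  # ~0.3192 -- matches the flux reference
print(old, new)
```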
@@ -2107,12 +2122,15 @@ class StableDiffusion:
        # train the guidance embedding
        if self.unet.config.guidance_embeds:
            transformer: FluxTransformer2DModel = self.unet
-            for name, param in transformer.time_text_embed.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+            for name, param in transformer.time_text_embed.named_parameters(recurse=True,
+                                                                            prefix=f"{SD_PREFIX_UNET}"):
                named_params[name] = param
-            for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+            for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
+                                                                             prefix=f"{SD_PREFIX_UNET}"):
                named_params[name] = param
-            for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+            for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True,
+                                                                                    prefix=f"{SD_PREFIX_UNET}"):
                named_params[name] = param
        else:
            for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):