mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
fixed huge flux training bug. Added ability to use an assistant lora
toolkit/assistant_lora.py  (new file, 55 lines)
@@ -0,0 +1,55 @@
+from typing import TYPE_CHECKING
+
+from toolkit.config_modules import NetworkConfig
+from toolkit.lora_special import LoRASpecialNetwork
+from safetensors.torch import load_file
+
+if TYPE_CHECKING:
+    from toolkit.stable_diffusion_model import StableDiffusion
+
+
+def load_assistant_lora_from_path(adapter_path, sd: 'StableDiffusion') -> LoRASpecialNetwork:
+    if not sd.is_flux:
+        raise ValueError("Only Flux models can load assistant adapters currently.")
+    pipe = sd.pipeline
+    print(f"Loading assistant adapter from {adapter_path}")
+    adapter_name = adapter_path.split("/")[-1].split(".")[0]
+    lora_state_dict = load_file(adapter_path)
+
+    linear_dim = int(lora_state_dict['transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight'].shape[0])
+    # linear_alpha = int(lora_state_dict['lora_transformer_single_transformer_blocks_0_attn_to_k.alpha'].item())
+    linear_alpha = linear_dim
+    transformer_only = 'transformer.proj_out.alpha' not in lora_state_dict
+    # get dim and scale
+    network_config = NetworkConfig(
+        linear=linear_dim,
+        linear_alpha=linear_alpha,
+        transformer_only=transformer_only,
+    )
+
+    network = LoRASpecialNetwork(
+        text_encoder=pipe.text_encoder,
+        unet=pipe.transformer,
+        lora_dim=network_config.linear,
+        multiplier=1.0,
+        alpha=network_config.linear_alpha,
+        train_unet=True,
+        train_text_encoder=False,
+        is_flux=True,
+        network_config=network_config,
+        network_type=network_config.type,
+        transformer_only=network_config.transformer_only,
+        is_assistant_adapter=True
+    )
+    network.apply_to(
+        pipe.text_encoder,
+        pipe.transformer,
+        apply_text_encoder=False,
+        apply_unet=True
+    )
+    network.force_to(sd.device_torch, dtype=sd.torch_dtype)
+    network.eval()
+    network._update_torch_multiplier()
+    network.load_weights(lora_state_dict)
+    network.is_active = True
+
+    return network
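For context, a minimal sketch of how this new loader is consumed (it mirrors the stable_diffusion_model.py change later in this commit; the attach_assistant helper is illustrative, only model_config.assistant_lora_path and the load_assistant_lora_from_path call come from the commit itself):

    from toolkit.assistant_lora import load_assistant_lora_from_path

    def attach_assistant(sd):
        # Sketch only: `sd` is a loaded StableDiffusion wrapper whose
        # model_config carries an assistant_lora_path, as added in this commit.
        if sd.model_config.assistant_lora_path is None:
            return None
        # Builds a transformer-only LoRASpecialNetwork, loads the safetensors
        # weights, and leaves it active at multiplier 1.0 on sd's device/dtype.
        sd.assistant_lora = load_assistant_lora_from_path(sd.model_config.assistant_lora_path, sd)
        return sd.assistant_lora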
@@ -175,6 +175,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             full_train_in_out: bool = False,
             transformer_only: bool = False,
             peft_format: bool = False,
+            is_assistant_adapter: bool = False,
             **kwargs
     ) -> None:
         """
@@ -223,6 +224,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.is_auraflow = is_auraflow
         self.is_flux = is_flux
         self.network_type = network_type
+        self.is_assistant_adapter = is_assistant_adapter
         if self.network_type.lower() == "dora":
             self.module_class = DoRAModule
             module_class = DoRAModule
@@ -263,7 +263,7 @@ class ToolkitModuleMixin:
         if isinstance(x, QTensor):
             x = x.dequantize()
         # always cast to float32
-        lora_input = x.float()
+        lora_input = x.to(self.lora_down.weight.dtype)
         lora_output = self._call_forward(lora_input)
         multiplier = self.network_ref().torch_multiplier
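A minimal sketch of what the one-line change above does in isolation: the LoRA input is now cast to the down-projection weight's dtype instead of always float32, so the LoRA matmul runs in whatever precision the adapter weights are stored in (the names below are illustrative, not the toolkit's module layout):

    import torch

    # stand-in for a LoRA down projection stored in bfloat16
    lora_down = torch.nn.Linear(64, 8, bias=False).to(torch.bfloat16)

    x = torch.randn(2, 64)                      # activations arrive as float32
    lora_input = x.to(lora_down.weight.dtype)   # new behavior: match the weight dtype
    lora_output = lora_down(lora_input)         # matmul runs in bfloat16, no dtype mismatch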
@@ -11,7 +11,8 @@ from collections import OrderedDict
 import copy
 import yaml
 from PIL import Image
-from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
+from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, \
+    ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
 from safetensors.torch import save_file, load_file
 from torch import autocast
@@ -20,6 +21,7 @@ from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
 from torchvision.transforms import Resize, transforms

+from toolkit.assistant_lora import load_assistant_lora_from_path
 from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.custom_adapter import CustomAdapter
 from toolkit.ip_adapter import IPAdapter
@@ -57,6 +59,10 @@ from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from toolkit.util.inverse_cfg import inverse_classifier_guidance

 from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from toolkit.lora_special import LoRASpecialNetwork

 # tell it to shut up
 diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -84,7 +90,6 @@ DO_NOT_TRAIN_WEIGHTS = [
 DeviceStatePreset = Literal['cache_latents', 'generate']


-
 class BlankNetwork:

     def __init__(self):
@@ -127,10 +132,12 @@ class StableDiffusion:
         self.torch_dtype = get_torch_dtype(dtype)
         self.device_torch = torch.device(self.device)

-        self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(model_config.vae_device)
+        self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(
+            model_config.vae_device)
         self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)

-        self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(model_config.te_device)
+        self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(
+            model_config.te_device)
         self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)

         self.model_config = model_config
@@ -146,6 +153,7 @@ class StableDiffusion:
         self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler

         self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None
+        self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None

         # sdxl stuff
         self.logit_scale = None
@@ -462,6 +470,8 @@ class StableDiffusion:
             print("Loading transformer")
             subfolder = 'transformer'
             transformer_path = model_path
+            local_files_only = False
+            # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set
             if os.path.exists(transformer_path):
                 subfolder = None
                 transformer_path = os.path.join(transformer_path, 'transformer')
@@ -518,7 +528,8 @@ class StableDiffusion:

             print("Loading t5")
             tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
-            text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
+            text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
+                                                            torch_dtype=dtype)

             text_encoder_2.to(self.device_torch, dtype=dtype)
             flush()
@@ -655,21 +666,17 @@ class StableDiffusion:
                 # unfortunately, not an easier way with peft
                 pipe.unload_lora_weights()

-            if self.model_config.assistant_lora_path is not None:
-                if self.model_config.lora_path is not None:
-                    raise ValueError("Cannot have both lora and assistant lora")
-                print("Loading assistant lora")
-                pipe.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
-                pipe.fuse_lora(lora_scale=1.0)
-                # unfortunately, not an easier way with peft
-                pipe.unload_lora_weights()
-
             self.tokenizer = tokenizer
             self.text_encoder = text_encoder
             self.pipeline = pipe
             self.load_refiner()
             self.is_loaded = True

+            if self.model_config.assistant_lora_path is not None:
+                print("Loading assistant lora")
+                self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
+                    self.model_config.assistant_lora_path, self)
+
             if self.is_pixart and self.vae_scale_factor == 16:
                 # TODO make our own pipeline?
                 # we generate an image 2x larger, so we need to copy the sizes from larger ones down
@@ -741,9 +748,7 @@ class StableDiffusion:
         if self.model_config.assistant_lora_path is not None:
             print("Unloading asistant lora")
-            # unfortunately, not an easier way with peft
-            self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
-            self.pipeline.fuse_lora(lora_scale=-1.0)
-            self.pipeline.unload_lora_weights()
+            self.assistant_lora.is_active = False

         if self.network is not None:
             self.network.eval()
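The removed lines switched the assistant adapter off by re-fusing it into the base weights at lora_scale=-1.0; the new code keeps the adapter as a separate LoRASpecialNetwork and simply flips its is_active flag. A toy sketch of that idea, assuming a standard LoRA residual (this is not the toolkit's implementation):

    import torch

    class ToyLoRALinear(torch.nn.Module):
        # base layer plus a low-rank delta that can be switched on and off
        def __init__(self, base: torch.nn.Linear, rank: int = 4):
            super().__init__()
            self.base = base
            self.down = torch.nn.Linear(base.in_features, rank, bias=False)
            self.up = torch.nn.Linear(rank, base.out_features, bias=False)
            self.is_active = True  # analogous to network.is_active in this commit

        def forward(self, x):
            y = self.base(x)
            if self.is_active:
                y = y + self.up(self.down(x))
            return y

Flipping is_active skips the delta entirely, so nothing has to be fused into or subtracted back out of the base weights each time the adapter is toggled.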
@@ -1027,7 +1032,6 @@ class StableDiffusion:

                 if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
                         and gen_config.adapter_image_path is not None:
-
                     # apply the image projection
                     conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
                     unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,
@@ -1035,7 +1039,8 @@ class StableDiffusion:
                     conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
                     unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)

-                if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
+                if self.adapter is not None and isinstance(self.adapter,
+                                                           CustomAdapter) and validation_image is not None:
                     conditional_embeds = self.adapter.condition_encoded_embeds(
                         tensors_0_1=validation_image,
                         prompt_embeds=conditional_embeds,
@@ -1052,14 +1057,15 @@ class StableDiffusion:
                         is_generating_samples=True,
                     )

-                if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len(gen_config.extra_values) > 0:
-                    extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, dtype=self.torch_dtype)
+                if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len(
+                        gen_config.extra_values) > 0:
+                    extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch,
+                                                dtype=self.torch_dtype)
                     # apply extra values to the embeddings
                     self.adapter.add_extra_values(extra_values, is_unconditional=False)
                     self.adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True)
                     pass  # todo remove, for debugging

-
                 if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
                     # if we have a refiner loaded, set the denoising end at the refiner start
                     extra['denoising_end'] = gen_config.refiner_start_at
@@ -1148,9 +1154,12 @@ class StableDiffusion:
                 img = pipeline(
                     prompt=None,
                     prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                    prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
-                    negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                    negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
+                    prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch,
+                                                                               dtype=self.unet.dtype),
+                    negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch,
+                                                                               dtype=self.unet.dtype),
+                    negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch,
+                                                                                          dtype=self.unet.dtype),
                     negative_prompt=None,
                     # negative_prompt=gen_config.negative_prompt,
                     height=gen_config.height,
@@ -1166,9 +1175,12 @@ class StableDiffusion:
                 img = pipeline(
                     prompt=None,
                     prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                    prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
-                    negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                    negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
+                    prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch,
+                                                                               dtype=self.unet.dtype),
+                    negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch,
+                                                                               dtype=self.unet.dtype),
+                    negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch,
+                                                                                          dtype=self.unet.dtype),
                     negative_prompt=None,
                     # negative_prompt=gen_config.negative_prompt,
                     height=gen_config.height,
@@ -1247,9 +1259,7 @@ class StableDiffusion:
         if self.model_config.assistant_lora_path is not None:
             print("Loading asistant lora")
-            # unfortunately, not an easier way with peft
-            self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
-            self.pipeline.fuse_lora(lora_scale=1.0)
-            self.pipeline.unload_lora_weights()
+            self.assistant_lora.is_active = True

     def get_latent_noise(
             self,
@@ -1332,7 +1342,8 @@ class StableDiffusion:
             noisy_latents_chunks = []

             for idx in range(original_samples.shape[0]):
-                noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx], timesteps_chunks[idx])
+                noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx],
+                                                               timesteps_chunks[idx])
                 noisy_latents_chunks.append(noisy_latents)

             noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
@@ -1392,7 +1403,6 @@ class StableDiffusion:
         else:
            timestep = timestep.repeat(latents.shape[0], 0)

-
         # handle t2i adapters
         if 'down_intrablock_additional_residuals' in kwargs:
             # go through each item and concat if doing cfg and it doesnt have the same shape
@@ -1561,7 +1571,6 @@ class StableDiffusion:
             height = h * VAE_SCALE_FACTOR
             width = w * VAE_SCALE_FACTOR

-
             if self.pipeline.transformer.config.sample_size == 256:
                 aspect_ratio_bin = ASPECT_RATIO_2048_BIN
             elif self.pipeline.transformer.config.sample_size == 128:
@@ -1573,10 +1582,12 @@ class StableDiffusion:
             else:
                 raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}")
             orig_height, orig_width = height, width
-            height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+            height, width = self.pipeline.image_processor.classify_height_width_bin(height, width,
+                                                                                    ratios=aspect_ratio_bin)

         added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
-        if self.unet.config.sample_size == 128 or (self.vae_scale_factor == 16 and self.unet.config.sample_size == 64):
+        if self.unet.config.sample_size == 128 or (
+                self.vae_scale_factor == 16 and self.unet.config.sample_size == 64):
             resolution = torch.tensor([height, width]).repeat(batch_size, 1)
             aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1)
             resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch)
@@ -1641,7 +1652,8 @@ class StableDiffusion:
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
                 # todo make sure this doesnt change
                 timestep=timestep / 1000,  # timestep is 1000 scale
-                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype),  # [1, 512, 4096]
+                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype),
+                # [1, 512, 4096]
                 pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype),  # [1, 768]
                 txt_ids=txt_ids,  # [1, 512, 3]
                 img_ids=img_ids,  # [1, 4096, 3]
@@ -1957,16 +1969,19 @@ class StableDiffusion:
         for i in range(len(image_list)):
             image = image_list[i]
             if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0:
-                image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)
+                image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR,
+                                        image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)

         images = torch.stack(image_list)
         if isinstance(self.vae, AutoencoderTiny):
             latents = self.vae.encode(images, return_dict=False)[0]
         else:
             latents = self.vae.encode(images).latent_dist.sample()
             # latents = self.vae.encode(images, return_dict=False)[0]
             shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
-            latents = latents * (self.vae.config['scaling_factor'] - shift)

+            # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303
+            # z = self.scale_factor * (z - self.shift_factor)
+            latents = self.vae.config['scaling_factor'] * (latents - shift)
         latents = latents.to(device, dtype=dtype)

         return latents
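The latent scaling change above is the "huge flux training bug" from the commit message: the Flux reference applies the shift before the scale, z = scale_factor * (z - shift_factor), while the removed line multiplied by (scaling_factor - shift) instead. A quick numeric comparison of the two formulas (the factor values below are illustrative, not read from a real VAE config):

    scaling_factor = 0.3611
    shift_factor = 0.1159
    latent = 1.0

    old = latent * (scaling_factor - shift_factor)   # 0.2452
    new = scaling_factor * (latent - shift_factor)   # about 0.3193
    print(old, new)

Latents encoded through this path were previously scaled with the wrong formula, which is consistent with the Flux training problems the commit message describes.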
@@ -2107,12 +2122,15 @@ class StableDiffusion:
             # train the guidance embedding
             if self.unet.config.guidance_embeds:
                 transformer: FluxTransformer2DModel = self.unet
-                for name, param in transformer.time_text_embed.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                for name, param in transformer.time_text_embed.named_parameters(recurse=True,
+                                                                                prefix=f"{SD_PREFIX_UNET}"):
                     named_params[name] = param

-            for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+            for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
+                                                                             prefix=f"{SD_PREFIX_UNET}"):
                 named_params[name] = param
-            for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+            for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True,
+                                                                                    prefix=f"{SD_PREFIX_UNET}"):
                 named_params[name] = param
         else:
             for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
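For reference, a small sketch of the named_parameters pattern this hunk relies on: passing a prefix namespaces every yielded parameter name, which is how these transformer parameters end up keyed under the unet prefix when collected for the optimizer (the toy module below is illustrative):

    import torch

    block = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))

    named_params = {}
    for name, param in block.named_parameters(recurse=True, prefix="unet"):
        named_params[name] = param

    print(list(named_params))  # ['unet.0.weight', 'unet.0.bias', 'unet.1.weight', 'unet.1.bias']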