Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Fixed huge flux training bug. Added ability to use an assistant lora
toolkit/assistant_lora.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+from typing import TYPE_CHECKING
+from toolkit.config_modules import NetworkConfig
+from toolkit.lora_special import LoRASpecialNetwork
+from safetensors.torch import load_file
+
+if TYPE_CHECKING:
+    from toolkit.stable_diffusion_model import StableDiffusion
+
+
+def load_assistant_lora_from_path(adapter_path, sd: 'StableDiffusion') -> LoRASpecialNetwork:
+    if not sd.is_flux:
+        raise ValueError("Only Flux models can load assistant adapters currently.")
+    pipe = sd.pipeline
+    print(f"Loading assistant adapter from {adapter_path}")
+    adapter_name = adapter_path.split("/")[-1].split(".")[0]
+    lora_state_dict = load_file(adapter_path)
+
+    linear_dim = int(lora_state_dict['transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight'].shape[0])
+    # linear_alpha = int(lora_state_dict['lora_transformer_single_transformer_blocks_0_attn_to_k.alpha'].item())
+    linear_alpha = linear_dim
+    transformer_only = 'transformer.proj_out.alpha' not in lora_state_dict
+    # get dim and scale
+    network_config = NetworkConfig(
+        linear=linear_dim,
+        linear_alpha=linear_alpha,
+        transformer_only=transformer_only,
+    )
+
+    network = LoRASpecialNetwork(
+        text_encoder=pipe.text_encoder,
+        unet=pipe.transformer,
+        lora_dim=network_config.linear,
+        multiplier=1.0,
+        alpha=network_config.linear_alpha,
+        train_unet=True,
+        train_text_encoder=False,
+        is_flux=True,
+        network_config=network_config,
+        network_type=network_config.type,
+        transformer_only=network_config.transformer_only,
+        is_assistant_adapter=True
+    )
+    network.apply_to(
+        pipe.text_encoder,
+        pipe.transformer,
+        apply_text_encoder=False,
+        apply_unet=True
+    )
+    network.force_to(sd.device_torch, dtype=sd.torch_dtype)
+    network.eval()
+    network._update_torch_multiplier()
+    network.load_weights(lora_state_dict)
+    network.is_active = True
+
+    return network
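For orientation, a minimal usage sketch of the new helper. The sd object and the adapter path below are placeholders: sd stands in for an already-loaded Flux StableDiffusion wrapper from this toolkit.

from toolkit.assistant_lora import load_assistant_lora_from_path

# Assumes `sd` is a StableDiffusion wrapper whose Flux pipeline is already loaded
# (sd.is_flux is True); the .safetensors path is a placeholder.
assistant = load_assistant_lora_from_path("/path/to/assistant_lora.safetensors", sd)

# The adapter comes back active; the rest of this commit toggles the flag to
# enable it for sampling and bypass it during training steps.
assistant.is_active = False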
toolkit/lora_special.py

@@ -175,6 +175,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             full_train_in_out: bool = False,
             transformer_only: bool = False,
             peft_format: bool = False,
+            is_assistant_adapter: bool = False,
             **kwargs
     ) -> None:
         """
@@ -223,6 +224,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.is_auraflow = is_auraflow
         self.is_flux = is_flux
         self.network_type = network_type
+        self.is_assistant_adapter = is_assistant_adapter
         if self.network_type.lower() == "dora":
            self.module_class = DoRAModule
            module_class = DoRAModule
toolkit/network_mixins.py

@@ -263,7 +263,7 @@ class ToolkitModuleMixin:
         if isinstance(x, QTensor):
             x = x.dequantize()
         # always cast to float32
-        lora_input = x.float()
+        lora_input = x.to(self.lora_down.weight.dtype)
         lora_output = self._call_forward(lora_input)
         multiplier = self.network_ref().torch_multiplier
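The one-line change above replaces a forced float32 upcast with a cast to the LoRA down-projection's own weight dtype. A minimal sketch of that pattern, using illustrative stand-in modules rather than the toolkit's classes:

import torch
import torch.nn as nn

# Stand-in for a LoRA pair whose weights live in bfloat16.
lora_down = nn.Linear(64, 8, bias=False).to(torch.bfloat16)
lora_up = nn.Linear(8, 64, bias=False).to(torch.bfloat16)

x = torch.randn(2, 64, dtype=torch.float16)   # incoming activation in a different dtype
lora_input = x.to(lora_down.weight.dtype)     # match the weight dtype (bfloat16 here)
lora_output = lora_up(lora_down(lora_input))  # low-rank path runs in a single dtype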
toolkit/stable_diffusion_model.py

@@ -11,7 +11,8 @@ from collections import OrderedDict
 import copy
 import yaml
 from PIL import Image
-from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
+from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, \
+    ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
 from safetensors.torch import save_file, load_file
 from torch import autocast
@@ -20,6 +21,7 @@ from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
 from torchvision.transforms import Resize, transforms
 
+from toolkit.assistant_lora import load_assistant_lora_from_path
 from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.custom_adapter import CustomAdapter
 from toolkit.ip_adapter import IPAdapter
@@ -57,6 +59,10 @@ from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from toolkit.util.inverse_cfg import inverse_classifier_guidance
 
 from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from toolkit.lora_special import LoRASpecialNetwork
 
 # tell it to shut up
 diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -84,7 +90,6 @@ DO_NOT_TRAIN_WEIGHTS = [
 DeviceStatePreset = Literal['cache_latents', 'generate']
 
 
-
 class BlankNetwork:
 
     def __init__(self):
@@ -127,10 +132,12 @@ class StableDiffusion:
         self.torch_dtype = get_torch_dtype(dtype)
         self.device_torch = torch.device(self.device)
 
-        self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(model_config.vae_device)
+        self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(
+            model_config.vae_device)
         self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
 
-        self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(model_config.te_device)
+        self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(
+            model_config.te_device)
         self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
 
         self.model_config = model_config
@@ -146,6 +153,7 @@ class StableDiffusion:
         self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
 
         self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None
+        self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None
 
         # sdxl stuff
         self.logit_scale = None
@@ -270,7 +278,7 @@ class StableDiffusion:
         # see if path exists
         if not os.path.exists(model_path) or os.path.isdir(model_path):
             try:
                 # try to load with default diffusers
                 pipe = pipln.from_pretrained(
                     model_path,
                     dtype=dtype,
@@ -462,6 +470,8 @@ class StableDiffusion:
             print("Loading transformer")
             subfolder = 'transformer'
             transformer_path = model_path
+            local_files_only = False
+            # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set
             if os.path.exists(transformer_path):
                 subfolder = None
                 transformer_path = os.path.join(transformer_path, 'transformer')
@@ -518,7 +528,8 @@ class StableDiffusion:
 
             print("Loading t5")
             tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
-            text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
+            text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
+                                                            torch_dtype=dtype)
 
             text_encoder_2.to(self.device_torch, dtype=dtype)
             flush()
@@ -655,21 +666,17 @@ class StableDiffusion:
                 # unfortunately, not an easier way with peft
                 pipe.unload_lora_weights()
 
-            if self.model_config.assistant_lora_path is not None:
-                if self.model_config.lora_path is not None:
-                    raise ValueError("Cannot have both lora and assistant lora")
-                print("Loading assistant lora")
-                pipe.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
-                pipe.fuse_lora(lora_scale=1.0)
-                # unfortunately, not an easier way with peft
-                pipe.unload_lora_weights()
-
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.pipeline = pipe
         self.load_refiner()
         self.is_loaded = True
 
+        if self.model_config.assistant_lora_path is not None:
+            print("Loading assistant lora")
+            self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
+                self.model_config.assistant_lora_path, self)
+
         if self.is_pixart and self.vae_scale_factor == 16:
             # TODO make our own pipeline?
             # we generate an image 2x larger, so we need to copy the sizes from larger ones down
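With this hunk, the assistant adapter is attached once at load time as a LoRASpecialNetwork on the wrapper, and the later hunks switch it with its is_active flag instead of fusing and unfusing LoRA weights through peft on every transition. A condensed sketch of the resulting flow (the training-step placeholder is illustrative, not toolkit API):

# after the model and assistant adapter have been loaded:
sd.assistant_lora.is_active = False   # deactivate before a training step on the main network
# ... run the training step ...
sd.assistant_lora.is_active = True    # reactivate before generating samples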
@@ -741,9 +748,7 @@ class StableDiffusion:
         if self.model_config.assistant_lora_path is not None:
             print("Unloading asistant lora")
             # unfortunately, not an easier way with peft
-            self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
-            self.pipeline.fuse_lora(lora_scale=-1.0)
-            self.pipeline.unload_lora_weights()
+            self.assistant_lora.is_active = False
 
         if self.network is not None:
             self.network.eval()
@@ -1027,7 +1032,6 @@ class StableDiffusion:
 
         if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
                 and gen_config.adapter_image_path is not None:
-
             # apply the image projection
             conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
             unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,
@@ -1035,7 +1039,8 @@ class StableDiffusion:
             conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
             unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
 
-        if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
+        if self.adapter is not None and isinstance(self.adapter,
+                                                   CustomAdapter) and validation_image is not None:
             conditional_embeds = self.adapter.condition_encoded_embeds(
                 tensors_0_1=validation_image,
                 prompt_embeds=conditional_embeds,
@@ -1052,13 +1057,14 @@ class StableDiffusion:
                 is_generating_samples=True,
             )
 
-        if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len(gen_config.extra_values) > 0:
-            extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, dtype=self.torch_dtype)
+        if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len(
+                gen_config.extra_values) > 0:
+            extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch,
+                                        dtype=self.torch_dtype)
             # apply extra values to the embeddings
             self.adapter.add_extra_values(extra_values, is_unconditional=False)
             self.adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True)
             pass # todo remove, for debugging
-
 
         if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
             # if we have a refiner loaded, set the denoising end at the refiner start
@@ -1148,9 +1154,12 @@ class StableDiffusion:
                 img = pipeline(
                     prompt=None,
                     prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                    prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
-                    negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                    negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
+                    prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch,
+                                                                               dtype=self.unet.dtype),
+                    negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch,
+                                                                               dtype=self.unet.dtype),
+                    negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch,
+                                                                                          dtype=self.unet.dtype),
                     negative_prompt=None,
                     # negative_prompt=gen_config.negative_prompt,
                     height=gen_config.height,
@@ -1166,9 +1175,12 @@ class StableDiffusion:
                 img = pipeline(
                     prompt=None,
                     prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                    prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
-                    negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                    negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
+                    prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch,
+                                                                               dtype=self.unet.dtype),
+                    negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch,
+                                                                               dtype=self.unet.dtype),
+                    negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch,
+                                                                                          dtype=self.unet.dtype),
                     negative_prompt=None,
                     # negative_prompt=gen_config.negative_prompt,
                     height=gen_config.height,
@@ -1247,9 +1259,7 @@ class StableDiffusion:
         if self.model_config.assistant_lora_path is not None:
             print("Loading asistant lora")
             # unfortunately, not an easier way with peft
-            self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
-            self.pipeline.fuse_lora(lora_scale=1.0)
-            self.pipeline.unload_lora_weights()
+            self.assistant_lora.is_active = True
 
     def get_latent_noise(
             self,
@@ -1332,7 +1342,8 @@ class StableDiffusion:
         noisy_latents_chunks = []
 
         for idx in range(original_samples.shape[0]):
-            noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx], timesteps_chunks[idx])
+            noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx],
+                                                           timesteps_chunks[idx])
             noisy_latents_chunks.append(noisy_latents)
 
         noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
@@ -1392,7 +1403,6 @@ class StableDiffusion:
         else:
             timestep = timestep.repeat(latents.shape[0], 0)
 
-
         # handle t2i adapters
         if 'down_intrablock_additional_residuals' in kwargs:
             # go through each item and concat if doing cfg and it doesnt have the same shape
@@ -1561,7 +1571,6 @@ class StableDiffusion:
             height = h * VAE_SCALE_FACTOR
             width = w * VAE_SCALE_FACTOR
 
-
             if self.pipeline.transformer.config.sample_size == 256:
                 aspect_ratio_bin = ASPECT_RATIO_2048_BIN
             elif self.pipeline.transformer.config.sample_size == 128:
@@ -1573,10 +1582,12 @@ class StableDiffusion:
             else:
                 raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}")
             orig_height, orig_width = height, width
-            height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+            height, width = self.pipeline.image_processor.classify_height_width_bin(height, width,
+                                                                                    ratios=aspect_ratio_bin)
 
             added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
-            if self.unet.config.sample_size == 128 or (self.vae_scale_factor == 16 and self.unet.config.sample_size == 64):
+            if self.unet.config.sample_size == 128 or (
+                    self.vae_scale_factor == 16 and self.unet.config.sample_size == 64):
                 resolution = torch.tensor([height, width]).repeat(batch_size, 1)
                 aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1)
                 resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch)
@@ -1641,7 +1652,8 @@ class StableDiffusion:
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
                 # todo make sure this doesnt change
                 timestep=timestep / 1000,  # timestep is 1000 scale
-                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype),  # [1, 512, 4096]
+                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype),
+                # [1, 512, 4096]
                 pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype),  # [1, 768]
                 txt_ids=txt_ids,  # [1, 512, 3]
                 img_ids=img_ids,  # [1, 4096, 3]
@@ -1705,7 +1717,7 @@ class StableDiffusion:
             with torch.no_grad():
                 # do cfg at the target rescale so we can match it
                 target_pred_mean_std = noise_pred_uncond + rescale_cfg * (
                     noise_pred_text - noise_pred_uncond
                 )
                 target_mean = target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach()
                 target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach()
@@ -1910,7 +1922,7 @@ class StableDiffusion:
                 self.text_encoder,
                 prompt,
                 truncate=not long_prompts,
                 max_length=77,  # todo set this higher when not transfer learning
                 dropout_prob=dropout_prob
             )
             return PromptEmbeds(
@@ -1957,16 +1969,19 @@ class StableDiffusion:
         for i in range(len(image_list)):
             image = image_list[i]
             if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0:
-                image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)
+                image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR,
+                                        image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)
 
         images = torch.stack(image_list)
         if isinstance(self.vae, AutoencoderTiny):
             latents = self.vae.encode(images, return_dict=False)[0]
         else:
             latents = self.vae.encode(images).latent_dist.sample()
-            # latents = self.vae.encode(images, return_dict=False)[0]
         shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
-        latents = latents * (self.vae.config['scaling_factor'] - shift)
+
+        # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303
+        # z = self.scale_factor * (z - self.shift_factor)
+        latents = self.vae.config['scaling_factor'] * (latents - shift)
         latents = latents.to(device, dtype=dtype)
 
         return latents
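The encode path above now normalizes latents as scaling_factor * (latents - shift), following the referenced Flux autoencoder, instead of the previous latents * (scaling_factor - shift). A small numeric illustration; the config values below are illustrative (approximately the published Flux VAE defaults):

scaling_factor, shift = 0.3611, 0.1159    # illustrative Flux VAE config values
z = 2.0                                   # an example raw latent value

old = z * (scaling_factor - shift)        # = 0.4904
new = scaling_factor * (z - shift)        # ~ 0.6803, i.e. scale * (z - shift) as in the Flux reference
print(old, new)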
@@ -2107,12 +2122,15 @@ class StableDiffusion:
             # train the guidance embedding
             if self.unet.config.guidance_embeds:
                 transformer: FluxTransformer2DModel = self.unet
-                for name, param in transformer.time_text_embed.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                for name, param in transformer.time_text_embed.named_parameters(recurse=True,
+                                                                                prefix=f"{SD_PREFIX_UNET}"):
                     named_params[name] = param
 
-                for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
+                                                                                 prefix=f"{SD_PREFIX_UNET}"):
                     named_params[name] = param
-                for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True,
+                                                                                        prefix=f"{SD_PREFIX_UNET}"):
                     named_params[name] = param
             else:
                 for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):