Added cogview4. Loss still needs work.

This commit is contained in:
Jaret Burkett
2025-03-04 18:43:52 -07:00
parent c57434ad7b
commit 6f6fb90812
12 changed files with 661 additions and 152 deletions

View File

@@ -380,9 +380,19 @@ class SDTrainer(BaseSDTrainProcess):
elif self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
elif hasattr(self.sd, 'get_loss_target'):
target = self.sd.get_loss_target(
noise=noise,
batch=batch,
timesteps=timesteps,
).detach()
elif self.sd.is_flow_matching:
# forward ODE
target = (noise - batch.latents).detach()
# reverse ODE
# target = (batch.latents - noise).detach()
else:
target = noise

View File

@@ -668,7 +668,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
# # prepare all the models stuff for accelerator (hopefully we dont miss any)
self.sd.vae = self.accelerator.prepare(self.sd.vae)
if self.sd.unet is not None:
self.sd.unet_unwrapped = self.sd.unet
self.sd.unet = self.accelerator.prepare(self.sd.unet)
# todo always tdo it?
self.modules_being_trained.append(self.sd.unet)
@@ -1105,11 +1104,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
if timestep_type is None:
timestep_type = self.train_config.timestep_type
patch_size = 1
if self.sd.is_flux:
# flux is a patch size of 1, but latents are divided by 2, so we need to double it
patch_size = 2
elif hasattr(self.sd.unet.config, 'patch_size'):
patch_size = self.sd.unet.config.patch_size
self.sd.noise_scheduler.set_train_timesteps(
num_train_timesteps,
device=self.device_torch,
timestep_type=timestep_type,
latents=latents
latents=latents,
patch_size=patch_size,
)
else:
self.sd.noise_scheduler.set_timesteps(
@@ -1403,21 +1410,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
model_config_to_load.name_or_path = latest_save_path
self.load_training_state_from_metadata(latest_save_path)
# get the noise scheduler
arch = 'sd'
if self.model_config.is_pixart:
arch = 'pixart'
if self.model_config.is_flux:
arch = 'flux'
if self.model_config.is_lumina2:
arch = 'lumina2'
sampler = get_sampler(
self.train_config.noise_scheduler,
{
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
},
arch=arch,
)
ModelClass = get_model_class(self.model_config)
# if the model class has get_train_scheduler static method
if hasattr(ModelClass, 'get_train_scheduler'):
sampler = ModelClass.get_train_scheduler()
else:
# get the noise scheduler
arch = 'sd'
if self.model_config.is_pixart:
arch = 'pixart'
if self.model_config.is_flux:
arch = 'flux'
if self.model_config.is_lumina2:
arch = 'lumina2'
sampler = get_sampler(
self.train_config.noise_scheduler,
{
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
},
arch=arch,
)
if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner')
@@ -1425,7 +1437,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
model_config_to_load.refiner_name_or_path = previous_refiner_save
self.load_training_state_from_metadata(previous_refiner_save)
ModelClass = get_model_class(self.model_config)
self.sd = ModelClass(
device=self.device,
model_config=model_config_to_load,
@@ -1562,6 +1573,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
# if is_lycoris:
# preset = PRESET['full']
# NetworkClass.apply_preset(preset)
if hasattr(self.sd, 'target_lora_modules'):
network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
self.network = NetworkClass(
text_encoder=text_encoder,
@@ -1590,6 +1604,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
network_config=self.network_config,
network_type=self.network_config.type,
transformer_only=self.network_config.transformer_only,
is_transformer=self.sd.is_transformer,
**network_kwargs
)

View File

@@ -1,8 +1,8 @@
torch==2.5.1
torchvision==0.20.1
safetensors
git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec
transformers
git+https://github.com/huggingface/diffusers@97fda1b75c70705b245a462044fedb47abb17e56
transformers==4.49.0
lycoris-lora==1.8.3
flatten_json
pyyaml

View File

@@ -29,7 +29,7 @@ def paramiter_count(model):
return int(paramiter_count)
def calculate_metrics(vae, images, max_imgs=-1):
def calculate_metrics(vae, images, max_imgs=-1, save_output=False):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = vae.to(device)
lpips_model = lpips.LPIPS(net='alex').to(device)
@@ -44,6 +44,9 @@ def calculate_metrics(vae, images, max_imgs=-1):
# ])
# needs values between -1 and 1
to_tensor = ToTensor()
# remove _reconstructed.png files
images = [img for img in images if not img.endswith("_reconstructed.png")]
if max_imgs > 0 and len(images) > max_imgs:
images = images[:max_imgs]
@@ -82,6 +85,15 @@ def calculate_metrics(vae, images, max_imgs=-1):
avg_rfid = 0
avg_psnr = sum(psnr_scores) / len(psnr_scores)
avg_lpips = sum(lpips_scores) / len(lpips_scores)
if save_output:
filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
folder = os.path.dirname(img_path)
save_path = os.path.join(folder, filename_no_ext + "_reconstructed.png")
reconstructed = (reconstructed + 1) / 2
reconstructed = reconstructed.clamp(0, 1)
reconstructed = transforms.ToPILImage()(reconstructed[0].cpu())
reconstructed.save(save_path)
return avg_rfid, avg_psnr, avg_lpips
@@ -91,18 +103,23 @@ def main():
parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model")
parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.")
# boolean store true
parser.add_argument("--save_output", action="store_true", help="Save the output images")
args = parser.parse_args()
if os.path.isfile(args.vae_path):
vae = AutoencoderKL.from_single_file(args.vae_path)
else:
vae = AutoencoderKL.from_pretrained(args.vae_path)
try:
vae = AutoencoderKL.from_pretrained(args.vae_path)
except:
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
vae.eval()
vae = vae.to(device)
print(f"Model has {paramiter_count(vae)} parameters")
images = load_images(args.image_folder)
avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs)
avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs, args.save_output)
# print(f"Average rFID: {avg_rfid}")
print(f"Average PSNR: {avg_psnr}")

View File

@@ -513,7 +513,7 @@ class ModelConfig:
self.te_name_or_path = kwargs.get("te_name_or_path", None)
self.arch: ModelArch = kwargs.get("model_arch", None)
self.arch: ModelArch = kwargs.get("arch", None)
# handle migrating to new model arch
if self.arch is None:

View File

@@ -178,6 +178,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
transformer_only: bool = False,
peft_format: bool = False,
is_assistant_adapter: bool = False,
is_transformer: bool = False,
**kwargs
) -> None:
"""
@@ -237,9 +238,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.network_config: NetworkConfig = kwargs.get("network_config", None)
self.peft_format = peft_format
self.is_transformer = is_transformer
# always do peft for flux only for now
if self.is_flux or self.is_v3 or self.is_lumina2:
if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer:
# don't do peft format for lokr
if self.network_type.lower() != "lokr":
self.peft_format = True
@@ -282,7 +285,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
unet_prefix = self.LORA_PREFIX_UNET
if self.peft_format:
unet_prefix = self.PEFT_PREFIX_UNET
if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2:
if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2 or self.is_transformer:
unet_prefix = f"lora_transformer"
if self.peft_format:
unet_prefix = "transformer"
@@ -341,6 +344,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if self.transformer_only and self.is_v3 and is_unet:
if "transformer_blocks" not in lora_name:
skip = True
# handle custom models
if self.transformer_only and is_unet and hasattr(root_module, 'transformer_blocks'):
if "transformer_blocks" not in lora_name:
skip = True
if (is_linear or is_conv2d) and not skip:

View File

@@ -168,11 +168,17 @@ class BaseModel:
self.invert_assistant_lora = False
self._after_sample_img_hooks = []
self._status_update_hooks = []
self.is_transformer = False
# properties for old arch for backwards compatibility
@property
def unet(self):
return self.model
# set unet to model
@unet.setter
def unet(self, value):
self.model = value
@property
def unet_unwrapped(self):
@@ -235,6 +241,7 @@ class BaseModel:
def generate_single_image(
self,
pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
@@ -257,6 +264,25 @@ class BaseModel:
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
raise NotImplementedError(
"get_prompt_embeds must be implemented in child classes")
def get_model_has_grad(self):
raise NotImplementedError(
"get_model_has_grad must be implemented in child classes")
def get_te_has_grad(self):
raise NotImplementedError(
"get_te_has_grad must be implemented in child classes")
def save_model(self, output_path, meta, save_dtype):
# todo handle dtype without overloading anything (vram, cpu, etc)
unwrap_model(self.pipeline).save_pretrained(
save_directory=output_path,
safe_serialization=True,
)
# save out meta config
meta_path = os.path.join(output_path, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
yaml.dump(meta, f)
# end must be implemented in child classes
def te_train(self):
@@ -512,6 +538,7 @@ class BaseModel:
self.device_torch, dtype=self.unet.dtype)
img = self.generate_single_image(
pipeline,
gen_config,
conditional_embeds,
unconditional_embeds,
@@ -603,7 +630,8 @@ class BaseModel:
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor
timesteps: torch.IntTensor,
**kwargs,
) -> torch.FloatTensor:
original_samples_chunks = torch.chunk(
original_samples, original_samples.shape[0], dim=0)
@@ -1071,7 +1099,7 @@ class BaseModel:
for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"):
named_params[name] = param
if unet:
if self.is_flux or self.is_lumina2:
if self.is_flux or self.is_lumina2 or self.is_transformer:
for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"):
named_params[name] = param
else:
@@ -1105,59 +1133,11 @@ class BaseModel:
return named_params
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
version_string = '1'
if self.is_v2:
version_string = '2'
if self.is_xl:
version_string = 'sdxl'
if self.is_ssd:
# overwrite sdxl because both wil be true here
version_string = 'ssd'
if self.is_ssd and self.is_vega:
version_string = 'vega'
# if output file does not end in .safetensors, then it is a directory and we are
# saving in diffusers format
if not output_file.endswith('.safetensors'):
# diffusers
if self.is_flux:
# only save the unet
transformer: FluxTransformer2DModel = unwrap_model(self.unet)
transformer.save_pretrained(
save_directory=os.path.join(output_file, 'transformer'),
safe_serialization=True,
)
elif self.is_lumina2:
# only save the unet
transformer: Lumina2Transformer2DModel = unwrap_model(
self.unet)
transformer.save_pretrained(
save_directory=os.path.join(output_file, 'transformer'),
safe_serialization=True,
)
else:
self.pipeline.save_pretrained(
save_directory=output_file,
safe_serialization=True,
)
# save out meta config
meta_path = os.path.join(output_file, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
yaml.dump(meta, f)
else:
save_ldm_model_from_diffusers(
sd=self,
output_file=output_file,
meta=meta,
save_dtype=save_dtype,
sd_version=version_string,
)
if self.config_file is not None:
output_path_no_ext = os.path.splitext(output_file)[0]
output_config_path = f"{output_path_no_ext}.yaml"
shutil.copyfile(self.config_file, output_config_path)
self.save_model(
output_path=output_file,
meta=meta,
save_dtype=save_dtype
)
def prepare_optimizer_params(
self,
@@ -1240,12 +1220,7 @@ class BaseModel:
def save_device_state(self):
# saves the current device state for all modules
# this is useful for when we want to alter the state and restore it
if self.is_lumina2:
unet_has_grad = self.unet.x_embedder.weight.requires_grad
elif self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
unet_has_grad = self.unet.proj_out.weight.requires_grad
else:
unet_has_grad = self.unet.conv_in.weight.requires_grad
unet_has_grad = self.get_model_has_grad()
self.device_state = {
**empty_preset,
@@ -1262,13 +1237,7 @@ class BaseModel:
if isinstance(self.text_encoder, list):
self.device_state['text_encoder']: List[dict] = []
for encoder in self.text_encoder:
if isinstance(encoder, LlamaModel):
te_has_grad = encoder.layers[0].mlp.gate_proj.weight.requires_grad
else:
try:
te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
except:
te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
te_has_grad = self.get_te_has_grad()
self.device_state['text_encoder'].append({
'training': encoder.training,
'device': encoder.device,
@@ -1276,17 +1245,7 @@ class BaseModel:
'requires_grad': te_has_grad
})
else:
if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel):
te_has_grad = self.text_encoder.encoder.block[
0].layer[0].SelfAttention.q.weight.requires_grad
elif isinstance(self.text_encoder, Gemma2Model):
te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
elif isinstance(self.text_encoder, Qwen2Model):
te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
elif isinstance(self.text_encoder, LlamaModel):
te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
else:
te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad
te_has_grad = self.get_te_has_grad()
self.device_state['text_encoder'] = {
'training': self.text_encoder.training,

458
toolkit/models/cogview4.py Normal file
View File

@@ -0,0 +1,458 @@
import weakref
from diffusers import CogView4Pipeline
import torch
import yaml
from toolkit.basic import flush
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
import os
import copy
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
import torch
import diffusers
from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
from transformers import GlmModel, AutoTokenizer
from diffusers import FlowMatchEulerDiscreteScheduler
from typing import TYPE_CHECKING
from toolkit.accelerator import unwrap_model
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
if TYPE_CHECKING:
from toolkit.lora_special import LoRASpecialNetwork
# remove this after a bug is fixed in diffusers code. This is a workaround.
class FakeModel:
def __init__(self, model):
self.model_ref = weakref.ref(model)
pass
@property
def device(self):
return self.model_ref().device
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": 0.25,
"invert_sigmas": False,
"max_image_seq_len": 4096,
"max_shift": 0.75,
"num_train_timesteps": 1000,
"shift": 1.0,
"shift_terminal": None,
"time_shift_type": "linear",
"use_beta_sigmas": False,
"use_dynamic_shifting": True,
"use_exponential_sigmas": False,
"use_karras_sigmas": False
}
class CogView4(BaseModel):
def __init__(
self,
device,
model_config: ModelConfig,
dtype='bf16',
custom_pipeline=None,
noise_scheduler=None,
**kwargs
):
super().__init__(device, model_config, dtype,
custom_pipeline, noise_scheduler, **kwargs)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ['CogView4Transformer2DModel']
# cache for holding noise
self.effective_noise = None
# static method to get the scheduler
@staticmethod
def get_train_scheduler():
scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
return scheduler
def load_model(self):
dtype = self.torch_dtype
base_model_path = "THUDM/CogView4-6B"
model_path = self.model_config.name_or_path
# pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
self.print_and_status_update("Loading CogView4 model")
# base_model_path = "black-forest-labs/FLUX.1-schnell"
base_model_path = self.model_config.name_or_path_original
subfolder = 'transformer'
transformer_path = model_path
if os.path.exists(transformer_path):
subfolder = None
transformer_path = os.path.join(transformer_path, 'transformer')
# check if the path is a full checkpoint.
te_folder_path = os.path.join(model_path, 'text_encoder')
# if we have the te, this folder is a full checkpoint, use it as the base
if os.path.exists(te_folder_path):
base_model_path = model_path
self.print_and_status_update("Loading GlmModel")
tokenizer = AutoTokenizer.from_pretrained(
base_model_path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder = GlmModel.from_pretrained(
base_model_path, subfolder="text_encoder", torch_dtype=dtype)
text_encoder.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing GlmModel")
quantize(text_encoder, weights=qfloat8)
freeze(text_encoder)
flush()
# hack to fix diffusers bug workaround
text_encoder.model = FakeModel(text_encoder)
self.print_and_status_update("Loading transformer")
transformer = CogView4Transformer2DModel.from_pretrained(
transformer_path,
subfolder=subfolder,
torch_dtype=dtype,
)
if self.model_config.split_model_over_gpus:
raise ValueError(
"Splitting model over gpus is not supported for CogViewModels models")
transformer.to(self.quantize_device, dtype=dtype)
flush()
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
raise ValueError(
"Assistant LoRA is not supported for CogViewModels models currently")
if self.model_config.lora_path is not None:
raise ValueError(
"Loading LoRA is not supported for CogViewModels models currently")
flush()
if self.model_config.quantize:
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = qfloat8
self.print_and_status_update("Quantizing transformer")
quantize(transformer, weights=quantization_type,
**self.model_config.quantize_kwargs)
freeze(transformer)
transformer.to(self.device_torch)
else:
transformer.to(self.device_torch, dtype=dtype)
flush()
scheduler = CogView4.get_train_scheduler()
self.print_and_status_update("Loading VAE")
vae = AutoencoderKL.from_pretrained(
base_model_path, subfolder="vae", torch_dtype=dtype)
flush()
self.print_and_status_update("Making pipe")
pipe: CogView4Pipeline = CogView4Pipeline(
scheduler=scheduler,
text_encoder=None,
tokenizer=tokenizer,
vae=vae,
transformer=None,
)
pipe.text_encoder = text_encoder
pipe.transformer = transformer
self.print_and_status_update("Preparing Model")
text_encoder = pipe.text_encoder
tokenizer = pipe.tokenizer
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
text_encoder.to(self.device_torch)
text_encoder.requires_grad_(False)
text_encoder.eval()
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
self.pipeline = pipe
self.model = transformer
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
def get_generation_pipeline(self):
scheduler = CogView4.get_train_scheduler()
pipeline = CogView4Pipeline(
vae=self.vae,
transformer=self.unet,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=scheduler,
)
return pipeline
def generate_single_image(
self,
pipeline: CogView4Pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
# there is a bug in the check in diffusers code that requires the prompt embeds to be the same length for conditional and unconditional
# they are processed in 2 passes and the encoding code doesnt do this. So it shouldnt be needed. But, we will zero pad the shorter one. for now. Just inference here, so it should be fine.
if conditional_embeds.text_embeds.shape[1] < unconditional_embeds.text_embeds.shape[1]:
pad_len = unconditional_embeds.text_embeds.shape[1] - \
conditional_embeds.text_embeds.shape[1]
conditional_embeds.text_embeds = torch.cat([conditional_embeds.text_embeds, torch.zeros(conditional_embeds.text_embeds.shape[0], pad_len,
conditional_embeds.text_embeds.shape[2], device=conditional_embeds.text_embeds.device, dtype=conditional_embeds.text_embeds.dtype)], dim=1)
elif conditional_embeds.text_embeds.shape[1] > unconditional_embeds.text_embeds.shape[1]:
pad_len = conditional_embeds.text_embeds.shape[1] - \
unconditional_embeds.text_embeds.shape[1]
unconditional_embeds.text_embeds = torch.cat([unconditional_embeds.text_embeds, torch.zeros(unconditional_embeds.text_embeds.shape[0], pad_len,
unconditional_embeds.text_embeds.shape[2], device=unconditional_embeds.text_embeds.device, dtype=unconditional_embeds.text_embeds.dtype)], dim=1)
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype),
negative_prompt_embeds=unconditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype),
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
generator=generator,
**extra
).images[0]
return img
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs
):
# target_size = (height, width)
target_size = latent_model_input.shape[-2:]
# multiply by 8
target_size = (target_size[0] * 8, target_size[1] * 8)
crops_coords_top_left = torch.tensor(
[(0, 0)], dtype=self.torch_dtype, device=self.device_torch)
original_size = torch.tensor(
[target_size], dtype=self.torch_dtype, device=self.device_torch)
target_size = original_size.clone()
noise_pred_cond = self.model(
hidden_states=latent_model_input, # torch.Size([1, 16, 128, 128])
encoder_hidden_states=text_embeddings.text_embeds, # torch.Size([1, 16, 4096])
timestep=timestep,
original_size=original_size, # [[1024., 1024.]]
target_size=target_size, # [[1024., 1024.]]
crop_coords=crops_coords_top_left, # [[0., 0.]]
return_dict=False,
)[0]
return noise_pred_cond
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
prompt_embeds, _ = self.pipeline.encode_prompt(
prompt,
do_classifier_free_guidance=False,
device=self.device_torch,
dtype=self.torch_dtype,
)
return PromptEmbeds(prompt_embeds)
def get_model_has_grad(self):
return self.model.proj_out.weight.requires_grad
def get_te_has_grad(self):
return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad
def save_model(self, output_path, meta, save_dtype):
# only save the unet
transformer: CogView4Transformer2DModel = unwrap_model(self.model)
transformer.save_pretrained(
save_directory=os.path.join(output_path, 'transformer'),
safe_serialization=True,
)
meta_path = os.path.join(output_path, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
yaml.dump(meta, f)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get('noise')
effective_noise = self.effective_noise
batch = kwargs.get('batch')
if batch is None:
raise ValueError("Batch is not provided")
if noise is None:
raise ValueError("Noise is not provided")
# return batch.latents
return (noise - batch.latents).detach()
# return (effective_noise - batch.latents).detach()
def _get_low_res_latents(self, latents):
# todo prevent needing to do this and grab the tensor another way.
with torch.no_grad():
# Decode latents to image space
images = self.decode_latents(latents, device=latents.device, dtype=latents.dtype)
# Downsample by a factor of 2 using bilinear interpolation
B, C, H, W = images.shape
low_res_images = torch.nn.functional.interpolate(
images,
size=(H // 4, W // 4),
mode="bilinear",
align_corners=False
)
# Upsample back to original resolution to match expected VAE input dimensions
upsampled_low_res_images = torch.nn.functional.interpolate(
low_res_images,
size=(H, W),
mode="bilinear",
align_corners=False
)
# Encode the low-resolution images back to latent space
low_res_latents = self.encode_images(upsampled_low_res_images, device=latents.device, dtype=latents.dtype)
return low_res_latents
# def add_noise(
# self,
# original_samples: torch.FloatTensor,
# noise: torch.FloatTensor,
# timesteps: torch.IntTensor,
# **kwargs,
# ) -> torch.FloatTensor:
# relay_start_point = 500
# # Store original samples for loss calculation
# self.original_samples = original_samples
# # Prepare chunks for batch processing
# original_samples_chunks = torch.chunk(
# original_samples, original_samples.shape[0], dim=0)
# noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
# timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
# # Get the low res latents only if needed
# low_res_latents_chunks = None
# # Handle case where timesteps is a single value for all samples
# if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
# timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
# noisy_latents_chunks = []
# effective_noise_chunks = [] # Store the effective noise for each sample
# for idx in range(original_samples.shape[0]):
# t = timesteps_chunks[idx]
# t_01 = (t / 1000).to(original_samples_chunks[idx].device)
# # Flowmatching interpolation between original and noise
# if t > relay_start_point:
# # Standard flowmatching - direct linear interpolation
# noisy_latents = (1 - t_01) * original_samples_chunks[idx] + t_01 * noise_chunks[idx]
# effective_noise_chunks.append(noise_chunks[idx]) # Effective noise is just the noise
# else:
# # Relay flowmatching case - only compute low_res_latents if needed
# if low_res_latents_chunks is None:
# low_res_latents = self._get_low_res_latents(original_samples)
# low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)
# # Calculate the relay ratio (0 to 1)
# t_ratio = t.float() / relay_start_point
# t_ratio = torch.clamp(t_ratio, 0.0, 1.0)
# # First blend between original and low-res based on t_ratio
# z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx]
# added_lor_res_noise = z0_t - original_samples_chunks[idx]
# # Then apply flowmatching interpolation between this blended state and noise
# noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx]
# # For prediction target, we need to store the effective "source"
# effective_noise_chunks.append(noise_chunks[idx] + added_lor_res_noise)
# noisy_latents_chunks.append(noisy_latents)
# noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
# self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation
# return noisy_latents
# def add_noise(
# self,
# original_samples: torch.FloatTensor,
# noise: torch.FloatTensor,
# timesteps: torch.IntTensor,
# **kwargs,
# ) -> torch.FloatTensor:
# relay_start_point = 500
# # Store original samples for loss calculation
# self.original_samples = original_samples
# # Prepare chunks for batch processing
# original_samples_chunks = torch.chunk(
# original_samples, original_samples.shape[0], dim=0)
# noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
# timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
# # Get the low res latents only if needed
# low_res_latents = self._get_low_res_latents(original_samples)
# low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)
# # Handle case where timesteps is a single value for all samples
# if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
# timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
# noisy_latents_chunks = []
# effective_noise_chunks = [] # Store the effective noise for each sample
# for idx in range(original_samples.shape[0]):
# t = timesteps_chunks[idx]
# t_01 = (t / 1000).to(original_samples_chunks[idx].device)
# lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx]
# lrln = lrln * (1 - t_01)
# # make the noise an interpolation between noise and low_res_latents with
# # being noise at t_01=1 and low_res_latents at t_01=0
# # new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln
# new_noise = noise_chunks[idx] + lrln
# # Then apply flowmatching interpolation between this blended state and noise
# noisy_latents = (1 - t_01) * original_samples + t_01 * new_noise
# # For prediction target, we need to store the effective "source"
# effective_noise_chunks.append(new_noise)
# noisy_latents_chunks.append(noisy_latents)
# noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
# self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation
# return noisy_latents

View File

@@ -36,12 +36,11 @@ class Wan21(BaseModel):
super().__init__(device, model_config, dtype,
custom_pipeline, noise_scheduler, **kwargs)
self.is_flow_matching = True
raise NotImplementedError("Wan21 is not implemented yet")
# these must be implemented in child classes
def load_model(self):
self.pipeline = Wan21(
)
pass
def get_generation_pipeline(self):
# override this in child classes
@@ -50,6 +49,7 @@ class Wan21(BaseModel):
def generate_single_image(
self,
pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
@@ -72,3 +72,11 @@ class Wan21(BaseModel):
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
raise NotImplementedError(
"get_prompt_embeds must be implemented in child classes")
def get_model_has_grad(self):
raise NotImplementedError(
"get_model_has_grad must be implemented in child classes")
def get_te_has_grad(self):
raise NotImplementedError(
"get_te_has_grad must be implemented in child classes")

View File

@@ -44,7 +44,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
# flatten second half to max
hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max()
hbsmntw_weighing[num_timesteps //
2:] = hbsmntw_weighing[num_timesteps // 2:].max()
# Create linear timesteps from 1000 to 0
timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
@@ -56,7 +57,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
step_indices = [(self.timesteps == t).nonzero().item()
for t in timesteps]
# Get the weights for the timesteps
if v2:
@@ -70,7 +72,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
sigmas = self.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = [(schedule_timesteps == t).nonzero().item()
for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
@@ -84,27 +87,24 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
## Add noise according to flow matching.
## zt = (1 - texp) * x + texp * z1
# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
# timestep needs to be in [0, 1], we store them in [0, 1000]
# noisy_sample = (1 - timestep) * latent + timestep * noise
t_01 = (timesteps / 1000).to(original_samples.device)
# forward ODE
noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
# n_dim = original_samples.ndim
# sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
# noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
# reverse ODE
# noisy_model_input = (1 - t_01) * noise + t_01 * original_samples
return noisy_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
return sample
def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None):
def set_train_timesteps(
self,
num_timesteps,
device,
timestep_type='linear',
latents=None,
patch_size=1
):
self.timestep_type = timestep_type
if timestep_type == 'linear':
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
@@ -124,42 +124,67 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
self.timesteps = timesteps.to(device=device)
return timesteps
elif timestep_type == 'flux_shift' or timestep_type == 'lumina2_shift':
elif timestep_type in ['flux_shift', 'lumina2_shift', 'shift']:
# matches inference dynamic shifting
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps
self._sigma_to_t(self.sigma_max), self._sigma_to_t(
self.sigma_min), num_timesteps
)
sigmas = timesteps / self.config.num_train_timesteps
if latents is None:
raise ValueError('latents is None')
h = latents.shape[2] // 2 # Divide by ph
w = latents.shape[3] // 2 # Divide by pw
image_seq_len = h * w
# todo need to know the mu for the shift
mu = calculate_shift(
image_seq_len,
self.config.get("base_image_seq_len", 256),
self.config.get("max_image_seq_len", 4096),
self.config.get("base_shift", 0.5),
self.config.get("max_shift", 1.16),
)
sigmas = self.time_shift(mu, 1.0, sigmas)
if self.config.use_dynamic_shifting:
if latents is None:
raise ValueError('latents is None')
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
# for flux we double up the patch size before sending her to simulate the latent reduction
h = latents.shape[2]
w = latents.shape[3]
image_seq_len = h * w // (patch_size**2)
mu = calculate_shift(
image_seq_len,
self.config.get("base_image_seq_len", 256),
self.config.get("max_image_seq_len", 4096),
self.config.get("base_shift", 0.5),
self.config.get("max_shift", 1.16),
)
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas)
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
sigmas = torch.from_numpy(sigmas).to(
dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
if self.config.invert_sigmas:
sigmas = 1.0 - sigmas
timesteps = sigmas * self.config.num_train_timesteps
sigmas = torch.cat(
[sigmas, torch.ones(1, device=sigmas.device)])
else:
sigmas = torch.cat(
[sigmas, torch.zeros(1, device=sigmas.device)])
self.timesteps = timesteps.to(device=device)
self.sigmas = sigmas
self.timesteps = timesteps.to(device=device)
return timesteps
elif timestep_type == 'lognorm_blend':
# disgtribute timestepd to the center/early and blend in linear
alpha = 0.75
@@ -173,7 +198,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
t1 = ((1 - t1/t1.max()) * 1000)
# add half of linear
t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device)
t2 = torch.linspace(1000, 0, int(
num_timesteps * (1 - alpha)), device=device)
timesteps = torch.cat((t1, t2))
# Sort the timesteps in descending order

View File

@@ -160,7 +160,6 @@ class StableDiffusion:
self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline']
self.vae: Union[None, 'AutoencoderKL']
self.unet: Union[None, 'UNet2DConditionModel']
self.unet_unwrapped: Union[None, 'UNet2DConditionModel']
self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
@@ -205,6 +204,8 @@ class StableDiffusion:
self.invert_assistant_lora = False
self._after_sample_img_hooks = []
self._status_update_hooks = []
# todo update this based on the model
self.is_transformer = False
# properties for old arch for backwards compatibility
@property
@@ -246,6 +247,10 @@ class StableDiffusion:
@property
def is_lumina2(self):
return self.arch == 'lumina2'
@property
def unet_unwrapped(self):
return unwrap_model(self.unet)
def load_model(self):
if self.is_loaded:
@@ -977,7 +982,6 @@ class StableDiffusion:
if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2:
# pixart and sd3 dont use a unet
self.unet = pipe.transformer
self.unet_unwrapped = pipe.transformer
else:
self.unet: 'UNet2DConditionModel' = pipe.unet
self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
@@ -1776,7 +1780,8 @@ class StableDiffusion:
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor
timesteps: torch.IntTensor,
**kwargs,
) -> torch.FloatTensor:
original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)

View File

@@ -5,5 +5,8 @@ def get_model_class(config: ModelConfig):
if config.arch == "wan21":
from toolkit.models.wan21 import Wan21
return Wan21
elif config.arch == "cogview4":
from toolkit.models.cogview4 import CogView4
return CogView4
else:
return StableDiffusion