Merge pull request #264 from ostris/cogview4

Added basics for CogView4. Broken as hell though. Don't use.
This commit is contained in:
Jaret Burkett
2025-03-05 14:52:06 -07:00
committed by GitHub
14 changed files with 2287 additions and 82 deletions

.gitmodules (vendored, 4 lines changed)

@@ -1,12 +1,16 @@
[submodule "repositories/sd-scripts"]
path = repositories/sd-scripts
url = https://github.com/kohya-ss/sd-scripts.git
commit = b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c
[submodule "repositories/leco"]
path = repositories/leco
url = https://github.com/p1atdev/LECO
commit = 9294adf40218e917df4516737afb13f069a6789d
[submodule "repositories/batch_annotator"]
path = repositories/batch_annotator
url = https://github.com/ostris/batch-annotator
commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf
[submodule "repositories/ipadapter"]
path = repositories/ipadapter
url = https://github.com/tencent-ailab/IP-Adapter.git
commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14


@@ -380,9 +380,19 @@ class SDTrainer(BaseSDTrainProcess):
elif self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
elif hasattr(self.sd, 'get_loss_target'):
target = self.sd.get_loss_target(
noise=noise,
batch=batch,
timesteps=timesteps,
).detach()
elif self.sd.is_flow_matching:
# forward ODE
target = (noise - batch.latents).detach()
# reverse ODE
# target = (batch.latents - noise).detach()
else:
target = noise
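For clarity, here is a minimal standalone sketch (the helper name is illustrative, not part of this commit) of how the flow-matching branch above pairs the noisy input with its forward-ODE target: the noisy sample is a linear interpolation between latents and noise, and the target is the constant velocity (noise - latents), matching the add_noise change later in this PR.
import torch
def flow_matching_pair(latents: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
    # timesteps are kept on a 0-1000 scale in this codebase; rescale to [0, 1]
    t = (timesteps / 1000).view(-1, *([1] * (latents.ndim - 1))).to(latents.device)
    noisy = (1.0 - t) * latents + t * noise   # forward ODE interpolation
    target = (noise - latents).detach()       # velocity the model is trained to predict
    return noisy, target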


@@ -68,6 +68,8 @@ import transformers
import diffusers
import hashlib
from toolkit.util.get_model import get_model_class
def flush():
torch.cuda.empty_cache()
gc.collect()
@@ -666,7 +668,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
# # prepare all the model components for the accelerator (hopefully we don't miss any)
self.sd.vae = self.accelerator.prepare(self.sd.vae)
if self.sd.unet is not None:
self.sd.unet_unwrapped = self.sd.unet
self.sd.unet = self.accelerator.prepare(self.sd.unet)
# todo: always do it?
self.modules_being_trained.append(self.sd.unet)
@@ -1103,11 +1104,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
if timestep_type is None:
timestep_type = self.train_config.timestep_type
patch_size = 1
if self.sd.is_flux:
# flux reports a patch size of 1 in its config, but its latents are packed 2x2, so double it here
patch_size = 2
elif hasattr(self.sd.unet.config, 'patch_size'):
patch_size = self.sd.unet.config.patch_size
self.sd.noise_scheduler.set_train_timesteps(
num_train_timesteps,
device=self.device_torch,
timestep_type=timestep_type,
latents=latents
latents=latents,
patch_size=patch_size,
)
else:
self.sd.noise_scheduler.set_timesteps(
@@ -1401,21 +1410,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
model_config_to_load.name_or_path = latest_save_path
self.load_training_state_from_metadata(latest_save_path)
# get the noise scheduler
arch = 'sd'
if self.model_config.is_pixart:
arch = 'pixart'
if self.model_config.is_flux:
arch = 'flux'
if self.model_config.is_lumina2:
arch = 'lumina2'
sampler = get_sampler(
self.train_config.noise_scheduler,
{
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
},
arch=arch,
)
ModelClass = get_model_class(self.model_config)
# if the model class provides a get_train_scheduler static method, use it
if hasattr(ModelClass, 'get_train_scheduler'):
sampler = ModelClass.get_train_scheduler()
else:
# get the noise scheduler
arch = 'sd'
if self.model_config.is_pixart:
arch = 'pixart'
if self.model_config.is_flux:
arch = 'flux'
if self.model_config.is_lumina2:
arch = 'lumina2'
sampler = get_sampler(
self.train_config.noise_scheduler,
{
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
},
arch=arch,
)
if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner')
@@ -1423,7 +1437,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
model_config_to_load.refiner_name_or_path = previous_refiner_save
self.load_training_state_from_metadata(previous_refiner_save)
self.sd = StableDiffusion(
self.sd = ModelClass(
device=self.device,
model_config=model_config_to_load,
dtype=self.train_config.dtype,
@@ -1559,6 +1573,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
# if is_lycoris:
# preset = PRESET['full']
# NetworkClass.apply_preset(preset)
if hasattr(self.sd, 'target_lora_modules'):
network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
self.network = NetworkClass(
text_encoder=text_encoder,
@@ -1587,6 +1604,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
network_config=self.network_config,
network_type=self.network_config.type,
transformer_only=self.network_config.transformer_only,
is_transformer=self.sd.is_transformer,
**network_kwargs
)


@@ -1,8 +1,8 @@
torch==2.5.1
torchvision==0.20.1
safetensors
git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec
transformers
git+https://github.com/huggingface/diffusers@24c062aaa19f5626d03d058daf8afffa2dfd49f7
transformers==4.49.0
lycoris-lora==1.8.3
flatten_json
pyyaml


@@ -29,7 +29,7 @@ def paramiter_count(model):
return int(paramiter_count)
def calculate_metrics(vae, images, max_imgs=-1):
def calculate_metrics(vae, images, max_imgs=-1, save_output=False):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = vae.to(device)
lpips_model = lpips.LPIPS(net='alex').to(device)
@@ -44,6 +44,9 @@ def calculate_metrics(vae, images, max_imgs=-1):
# ])
# needs values between -1 and 1
to_tensor = ToTensor()
# remove _reconstructed.png files
images = [img for img in images if not img.endswith("_reconstructed.png")]
if max_imgs > 0 and len(images) > max_imgs:
images = images[:max_imgs]
@@ -82,6 +85,15 @@ def calculate_metrics(vae, images, max_imgs=-1):
avg_rfid = 0
avg_psnr = sum(psnr_scores) / len(psnr_scores)
avg_lpips = sum(lpips_scores) / len(lpips_scores)
if save_output:
filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
folder = os.path.dirname(img_path)
save_path = os.path.join(folder, filename_no_ext + "_reconstructed.png")
reconstructed = (reconstructed + 1) / 2
reconstructed = reconstructed.clamp(0, 1)
reconstructed = transforms.ToPILImage()(reconstructed[0].cpu())
reconstructed.save(save_path)
return avg_rfid, avg_psnr, avg_lpips
@@ -91,18 +103,23 @@ def main():
parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model")
parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.")
# boolean store true
parser.add_argument("--save_output", action="store_true", help="Save the output images")
args = parser.parse_args()
if os.path.isfile(args.vae_path):
vae = AutoencoderKL.from_single_file(args.vae_path)
else:
vae = AutoencoderKL.from_pretrained(args.vae_path)
try:
vae = AutoencoderKL.from_pretrained(args.vae_path)
except:
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
vae.eval()
vae = vae.to(device)
print(f"Model has {paramiter_count(vae)} parameters")
images = load_images(args.image_folder)
avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs)
avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs, args.save_output)
# print(f"Average rFID: {avg_rfid}")
print(f"Average PSNR: {avg_psnr}")


@@ -432,6 +432,9 @@ class TrainConfig:
self.force_consistent_noise = kwargs.get('force_consistent_noise', False)
ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']
class ModelConfig:
def __init__(self, **kwargs):
self.name_or_path: str = kwargs.get('name_or_path', None)
@@ -509,6 +512,36 @@ class ModelConfig:
self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3)
self.te_name_or_path = kwargs.get("te_name_or_path", None)
self.arch: ModelArch = kwargs.get("arch", None)
# handle migrating to new model arch
if self.arch is None:
if kwargs.get('is_v2', False):
self.arch = 'sd2'
elif kwargs.get('is_v3', False):
self.arch = 'sd3'
elif kwargs.get('is_xl', False):
self.arch = 'sdxl'
elif kwargs.get('is_pixart', False):
self.arch = 'pixart'
elif kwargs.get('is_pixart_sigma', False):
self.arch = 'pixart_sigma'
elif kwargs.get('is_auraflow', False):
self.arch = 'auraflow'
elif kwargs.get('is_flux', False):
self.arch = 'flux'
elif kwargs.get('is_flex2', False):
self.arch = 'flex2'
elif kwargs.get('is_lumina2', False):
self.arch = 'lumina2'
elif kwargs.get('is_vega', False):
self.arch = 'vega'
elif kwargs.get('is_ssd', False):
self.arch = 'ssd'
else:
self.arch = 'sd1'
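A small sketch of what this migration means in practice, assuming the constructor accepts keyword arguments as shown above (the model path is illustrative): a legacy config using the old boolean flags and a new config using arch resolve to the same architecture string.
legacy = ModelConfig(name_or_path="black-forest-labs/FLUX.1-dev", is_flux=True)
modern = ModelConfig(name_or_path="black-forest-labs/FLUX.1-dev", arch="flux")
assert legacy.arch == modern.arch == "flux"
# with neither arch nor any legacy flag set, the config falls back to 'sd1'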
class EMAConfig:


@@ -178,6 +178,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
transformer_only: bool = False,
peft_format: bool = False,
is_assistant_adapter: bool = False,
is_transformer: bool = False,
**kwargs
) -> None:
"""
@@ -237,9 +238,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.network_config: NetworkConfig = kwargs.get("network_config", None)
self.peft_format = peft_format
self.is_transformer = is_transformer
# always do peft for flux only for now
if self.is_flux or self.is_v3 or self.is_lumina2:
if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer:
# don't do peft format for lokr
if self.network_type.lower() != "lokr":
self.peft_format = True
@@ -282,7 +285,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
unet_prefix = self.LORA_PREFIX_UNET
if self.peft_format:
unet_prefix = self.PEFT_PREFIX_UNET
if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2:
if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2 or self.is_transformer:
unet_prefix = f"lora_transformer"
if self.peft_format:
unet_prefix = "transformer"
@@ -341,6 +344,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if self.transformer_only and self.is_v3 and is_unet:
if "transformer_blocks" not in lora_name:
skip = True
# handle custom models
if self.transformer_only and is_unet and hasattr(root_module, 'transformer_blocks'):
if "transformer_blocks" not in lora_name:
skip = True
if (is_linear or is_conv2d) and not skip:

toolkit/models/base_model.py (new file, 1426 lines)

Diff suppressed because it is too large.

toolkit/models/cogview4.py (new file, 466 lines)

@@ -0,0 +1,466 @@
# DON'T USE THIS! IT DOES NOT WORK YET!
# Will revisit this when they release more info on how it was trained.
import weakref
from diffusers import CogView4Pipeline
import torch
import yaml
from toolkit.basic import flush
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
import os
import copy
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
import torch
import diffusers
from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline
from optimum.quanto import freeze, qfloat8, QTensor, qint4
from toolkit.util.quantize import quantize
from transformers import GlmModel, AutoTokenizer
from diffusers import FlowMatchEulerDiscreteScheduler
from typing import TYPE_CHECKING
from toolkit.accelerator import unwrap_model
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
if TYPE_CHECKING:
from toolkit.lora_special import LoRASpecialNetwork
# Workaround for a bug in the diffusers code; remove this once it is fixed upstream.
class FakeModel:
def __init__(self, model):
self.model_ref = weakref.ref(model)
pass
@property
def device(self):
return self.model_ref().device
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": 0.25,
"invert_sigmas": False,
"max_image_seq_len": 4096,
"max_shift": 0.75,
"num_train_timesteps": 1000,
"shift": 1.0,
"shift_terminal": None,
"time_shift_type": "linear",
"use_beta_sigmas": False,
"use_dynamic_shifting": True,
"use_exponential_sigmas": False,
"use_karras_sigmas": False
}
class CogView4(BaseModel):
def __init__(
self,
device,
model_config: ModelConfig,
dtype='bf16',
custom_pipeline=None,
noise_scheduler=None,
**kwargs
):
super().__init__(device, model_config, dtype,
custom_pipeline, noise_scheduler, **kwargs)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ['CogView4Transformer2DModel']
# cache for holding noise
self.effective_noise = None
# static method to get the scheduler
@staticmethod
def get_train_scheduler():
scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
return scheduler
def load_model(self):
dtype = self.torch_dtype
base_model_path = "THUDM/CogView4-6B"
model_path = self.model_config.name_or_path
self.print_and_status_update("Loading CogView4 model")
# base_model_path = "black-forest-labs/FLUX.1-schnell"
base_model_path = self.model_config.name_or_path_original
subfolder = 'transformer'
transformer_path = model_path
if os.path.exists(transformer_path):
subfolder = None
transformer_path = os.path.join(transformer_path, 'transformer')
# check if the path is a full checkpoint.
te_folder_path = os.path.join(model_path, 'text_encoder')
# if we have the te, this folder is a full checkpoint, use it as the base
if os.path.exists(te_folder_path):
base_model_path = model_path
self.print_and_status_update("Loading GlmModel")
tokenizer = AutoTokenizer.from_pretrained(
base_model_path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder = GlmModel.from_pretrained(
base_model_path, subfolder="text_encoder", torch_dtype=dtype)
text_encoder.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing GlmModel")
quantize(text_encoder, weights=qfloat8)
freeze(text_encoder)
flush()
# workaround for a diffusers bug (see FakeModel above)
text_encoder.model = FakeModel(text_encoder)
self.print_and_status_update("Loading transformer")
transformer = CogView4Transformer2DModel.from_pretrained(
transformer_path,
subfolder=subfolder,
torch_dtype=dtype,
)
if self.model_config.split_model_over_gpus:
raise ValueError(
"Splitting model over gpus is not supported for CogViewModels models")
transformer.to(self.quantize_device, dtype=dtype)
flush()
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
raise ValueError(
"Assistant LoRA is not supported for CogViewModels models currently")
if self.model_config.lora_path is not None:
raise ValueError(
"Loading LoRA is not supported for CogViewModels models currently")
flush()
if self.model_config.quantize:
quantization_args = self.model_config.quantize_kwargs
if 'exclude' not in quantization_args:
quantization_args['exclude'] = []
if 'include' not in quantization_args:
quantization_args['include'] = []
# Be more specific with the include pattern to exactly match transformer blocks
quantization_args['include'] += ["transformer_blocks.*"]
# Exclude all LayerNorm layers within transformer blocks
quantization_args['exclude'] += [
"transformer_blocks.*.norm1",
"transformer_blocks.*.norm2",
"transformer_blocks.*.norm2_context",
"transformer_blocks.*.attn1.norm_q",
"transformer_blocks.*.attn1.norm_k"
]
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = qfloat8
self.print_and_status_update("Quantizing transformer")
quantize(transformer, weights=quantization_type, **quantization_args)
freeze(transformer)
transformer.to(self.device_torch)
else:
transformer.to(self.device_torch, dtype=dtype)
flush()
scheduler = CogView4.get_train_scheduler()
self.print_and_status_update("Loading VAE")
vae = AutoencoderKL.from_pretrained(
base_model_path, subfolder="vae", torch_dtype=dtype)
flush()
self.print_and_status_update("Making pipe")
pipe: CogView4Pipeline = CogView4Pipeline(
scheduler=scheduler,
text_encoder=None,
tokenizer=tokenizer,
vae=vae,
transformer=None,
)
pipe.text_encoder = text_encoder
pipe.transformer = transformer
self.print_and_status_update("Preparing Model")
text_encoder = pipe.text_encoder
tokenizer = pipe.tokenizer
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
text_encoder.to(self.device_torch)
text_encoder.requires_grad_(False)
text_encoder.eval()
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
self.pipeline = pipe
self.model = transformer
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
def get_generation_pipeline(self):
scheduler = CogView4.get_train_scheduler()
pipeline = CogView4Pipeline(
vae=self.vae,
transformer=self.unet,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=scheduler,
)
return pipeline
def generate_single_image(
self,
pipeline: CogView4Pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype),
negative_prompt_embeds=unconditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype),
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
generator=generator,
**extra
).images[0]
return img
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs
):
# target_size = (height, width)
target_size = latent_model_input.shape[-2:]
# multiply by 8 (the VAE downscale factor) to get pixel dimensions
target_size = (target_size[0] * 8, target_size[1] * 8)
crops_coords_top_left = torch.tensor(
[(0, 0)], dtype=self.torch_dtype, device=self.device_torch)
original_size = torch.tensor(
[target_size], dtype=self.torch_dtype, device=self.device_torch)
target_size = original_size.clone()
noise_pred_cond = self.model(
hidden_states=latent_model_input,
encoder_hidden_states=text_embeddings.text_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
return_dict=False,
)[0]
return noise_pred_cond
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
prompt_embeds, _ = self.pipeline.encode_prompt(
prompt,
do_classifier_free_guidance=False,
device=self.device_torch,
dtype=self.torch_dtype,
)
return PromptEmbeds(prompt_embeds)
def get_model_has_grad(self):
return self.model.proj_out.weight.requires_grad
def get_te_has_grad(self):
return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad
def save_model(self, output_path, meta, save_dtype):
# only save the transformer
transformer: CogView4Transformer2DModel = unwrap_model(self.model)
transformer.save_pretrained(
save_directory=os.path.join(output_path, 'transformer'),
safe_serialization=True,
)
meta_path = os.path.join(output_path, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
yaml.dump(meta, f)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get('noise')
effective_noise = self.effective_noise
batch = kwargs.get('batch')
if batch is None:
raise ValueError("Batch is not provided")
if noise is None:
raise ValueError("Noise is not provided")
# return batch.latents
# return (batch.latents - noise).detach()
return (noise - batch.latents).detach()
# return (batch.latents).detach()
# return (effective_noise - batch.latents).detach()
def _get_low_res_latents(self, latents):
# todo prevent needing to do this and grab the tensor another way.
with torch.no_grad():
# Decode latents to image space
images = self.decode_latents(
latents, device=latents.device, dtype=latents.dtype)
# Downsample by a factor of 2 using bilinear interpolation
B, C, H, W = images.shape
low_res_images = torch.nn.functional.interpolate(
images,
size=(H // 2, W // 2),
mode="bilinear",
align_corners=False
)
# Upsample back to original resolution to match expected VAE input dimensions
upsampled_low_res_images = torch.nn.functional.interpolate(
low_res_images,
size=(H, W),
mode="bilinear",
align_corners=False
)
# Encode the low-resolution images back to latent space
low_res_latents = self.encode_images(
upsampled_low_res_images, device=latents.device, dtype=latents.dtype)
return low_res_latents
# def add_noise(
# self,
# original_samples: torch.FloatTensor,
# noise: torch.FloatTensor,
# timesteps: torch.IntTensor,
# **kwargs,
# ) -> torch.FloatTensor:
# relay_start_point = 500
# # Store original samples for loss calculation
# self.original_samples = original_samples
# # Prepare chunks for batch processing
# original_samples_chunks = torch.chunk(
# original_samples, original_samples.shape[0], dim=0)
# noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
# timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
# # Get the low res latents only if needed
# low_res_latents_chunks = None
# # Handle case where timesteps is a single value for all samples
# if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
# timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
# noisy_latents_chunks = []
# effective_noise_chunks = [] # Store the effective noise for each sample
# for idx in range(original_samples.shape[0]):
# t = timesteps_chunks[idx]
# t_01 = (t / 1000).to(original_samples_chunks[idx].device)
# # Flowmatching interpolation between original and noise
# if t > relay_start_point:
# # Standard flowmatching - direct linear interpolation
# noisy_latents = (1 - t_01) * original_samples_chunks[idx] + t_01 * noise_chunks[idx]
# effective_noise_chunks.append(noise_chunks[idx]) # Effective noise is just the noise
# else:
# # Relay flowmatching case - only compute low_res_latents if needed
# if low_res_latents_chunks is None:
# low_res_latents = self._get_low_res_latents(original_samples)
# low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)
# # Calculate the relay ratio (0 to 1)
# t_ratio = t.float() / relay_start_point
# t_ratio = torch.clamp(t_ratio, 0.0, 1.0)
# # First blend between original and low-res based on t_ratio
# z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx]
# added_lor_res_noise = z0_t - original_samples_chunks[idx]
# # Then apply flowmatching interpolation between this blended state and noise
# noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx]
# # For prediction target, we need to store the effective "source"
# effective_noise_chunks.append(noise_chunks[idx] + added_lor_res_noise)
# noisy_latents_chunks.append(noisy_latents)
# noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
# self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation
# return noisy_latents
# def add_noise(
# self,
# original_samples: torch.FloatTensor,
# noise: torch.FloatTensor,
# timesteps: torch.IntTensor,
# **kwargs,
# ) -> torch.FloatTensor:
# relay_start_point = 500
# # Store original samples for loss calculation
# self.original_samples = original_samples
# # Prepare chunks for batch processing
# original_samples_chunks = torch.chunk(
# original_samples, original_samples.shape[0], dim=0)
# noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
# timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
# # Get the low res latents only if needed
# low_res_latents = self._get_low_res_latents(original_samples)
# low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)
# # Handle case where timesteps is a single value for all samples
# if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
# timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
# noisy_latents_chunks = []
# effective_noise_chunks = [] # Store the effective noise for each sample
# for idx in range(original_samples.shape[0]):
# t = timesteps_chunks[idx]
# t_01 = (t / 1000).to(original_samples_chunks[idx].device)
# lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx]
# # lrln = lrln * (1 - t_01)
# # make the noise an interpolation between noise and low_res_latents with
# # being noise at t_01=1 and low_res_latents at t_01=0
# new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln
# # new_noise = noise_chunks[idx] + lrln
# # new_noise = noise_chunks[idx] + lrln
# # Then apply flowmatching interpolation between this blended state and noise
# noisy_latents = (1 - t_01) * original_samples + t_01 * new_noise
# # For prediction target, we need to store the effective "source"
# effective_noise_chunks.append(new_noise)
# noisy_latents_chunks.append(noisy_latents)
# noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
# self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation
# return noisy_latents

toolkit/models/wan21.py (new file, 82 lines)

@@ -0,0 +1,82 @@
# WIP, coming soon ish
import torch
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
from toolkit.paths import REPOS_ROOT
import sys
import os
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
class Wan21(BaseModel):
def __init__(
self,
device,
model_config: ModelConfig,
dtype='bf16',
custom_pipeline=None,
noise_scheduler=None,
**kwargs
):
super().__init__(device, model_config, dtype,
custom_pipeline, noise_scheduler, **kwargs)
self.is_flow_matching = True
raise NotImplementedError("Wan21 is not implemented yet")
# these must be implemented in child classes
def load_model(self):
pass
def get_generation_pipeline(self):
# override this in child classes
raise NotImplementedError(
"get_generation_pipeline must be implemented in child classes")
def generate_single_image(
self,
pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
# override this in child classes
raise NotImplementedError(
"generate_single_image must be implemented in child classes")
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs
):
raise NotImplementedError(
"get_noise_prediction must be implemented in child classes")
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
raise NotImplementedError(
"get_prompt_embeds must be implemented in child classes")
def get_model_has_grad(self):
raise NotImplementedError(
"get_model_has_grad must be implemented in child classes")
def get_te_has_grad(self):
raise NotImplementedError(
"get_te_has_grad must be implemented in child classes")


@@ -44,7 +44,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
# flatten second half to max
hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max()
hbsmntw_weighing[num_timesteps //
2:] = hbsmntw_weighing[num_timesteps // 2:].max()
# Create linear timesteps from 1000 to 0
timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
@@ -56,7 +57,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
step_indices = [(self.timesteps == t).nonzero().item()
for t in timesteps]
# Get the weights for the timesteps
if v2:
@@ -70,7 +72,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
sigmas = self.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = [(schedule_timesteps == t).nonzero().item()
for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
@@ -84,27 +87,24 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
## Add noise according to flow matching.
## zt = (1 - texp) * x + texp * z1
# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
# timestep needs to be in [0, 1], we store them in [0, 1000]
# noisy_sample = (1 - timestep) * latent + timestep * noise
t_01 = (timesteps / 1000).to(original_samples.device)
noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
# n_dim = original_samples.ndim
# sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
# noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
# forward ODE
noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
# reverse ODE
# noisy_model_input = (1 - t_01) * noise + t_01 * original_samples
return noisy_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
return sample
def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None):
def set_train_timesteps(
self,
num_timesteps,
device,
timestep_type='linear',
latents=None,
patch_size=1
):
self.timestep_type = timestep_type
if timestep_type == 'linear':
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
@@ -124,42 +124,67 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
self.timesteps = timesteps.to(device=device)
return timesteps
elif timestep_type == 'flux_shift' or timestep_type == 'lumina2_shift':
elif timestep_type in ['flux_shift', 'lumina2_shift', 'shift']:
# matches inference dynamic shifting
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps
self._sigma_to_t(self.sigma_max), self._sigma_to_t(
self.sigma_min), num_timesteps
)
sigmas = timesteps / self.config.num_train_timesteps
if latents is None:
raise ValueError('latents is None')
h = latents.shape[2] // 2 # Divide by ph
w = latents.shape[3] // 2 # Divide by pw
image_seq_len = h * w
# todo need to know the mu for the shift
mu = calculate_shift(
image_seq_len,
self.config.get("base_image_seq_len", 256),
self.config.get("max_image_seq_len", 4096),
self.config.get("base_shift", 0.5),
self.config.get("max_shift", 1.16),
)
sigmas = self.time_shift(mu, 1.0, sigmas)
if self.config.use_dynamic_shifting:
if latents is None:
raise ValueError('latents is None')
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
# for flux we double the patch size before sending it here to simulate the latent packing
h = latents.shape[2]
w = latents.shape[3]
image_seq_len = h * w // (patch_size**2)
mu = calculate_shift(
image_seq_len,
self.config.get("base_image_seq_len", 256),
self.config.get("max_image_seq_len", 4096),
self.config.get("base_shift", 0.5),
self.config.get("max_shift", 1.16),
)
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas)
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
sigmas = torch.from_numpy(sigmas).to(
dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
if self.config.invert_sigmas:
sigmas = 1.0 - sigmas
timesteps = sigmas * self.config.num_train_timesteps
sigmas = torch.cat(
[sigmas, torch.ones(1, device=sigmas.device)])
else:
sigmas = torch.cat(
[sigmas, torch.zeros(1, device=sigmas.device)])
self.timesteps = timesteps.to(device=device)
self.sigmas = sigmas
self.timesteps = timesteps.to(device=device)
return timesteps
elif timestep_type == 'lognorm_blend':
# distribute timesteps toward the center/early region and blend in a linear schedule
alpha = 0.75
@@ -173,7 +198,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
t1 = ((1 - t1/t1.max()) * 1000)
# add half of linear
t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device)
t2 = torch.linspace(1000, 0, int(
num_timesteps * (1 - alpha)), device=device)
timesteps = torch.cat((t1, t2))
# Sort the timesteps in descending order


@@ -29,7 +29,7 @@ from toolkit.ip_adapter import IPAdapter
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict, load_vae
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.decorator import Decorator
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
@@ -64,7 +64,8 @@ from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from huggingface_hub import hf_hub_download
from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
from optimum.quanto import freeze, qfloat8, QTensor, qint4
from toolkit.util.quantize import quantize
from toolkit.accelerator import get_accelerator, unwrap_model
from typing import TYPE_CHECKING
from toolkit.print import print_acc
@@ -160,7 +161,6 @@ class StableDiffusion:
self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline']
self.vae: Union[None, 'AutoencoderKL']
self.unet: Union[None, 'UNet2DConditionModel']
self.unet_unwrapped: Union[None, 'UNet2DConditionModel']
self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
@@ -177,16 +177,17 @@ class StableDiffusion:
self.network = None
self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
self.decorator: Union[Decorator, None] = None
self.is_xl = model_config.is_xl
self.is_v2 = model_config.is_v2
self.is_ssd = model_config.is_ssd
self.is_v3 = model_config.is_v3
self.is_vega = model_config.is_vega
self.is_pixart = model_config.is_pixart
self.is_auraflow = model_config.is_auraflow
self.is_flux = model_config.is_flux
self.is_flex2 = model_config.is_flex2
self.is_lumina2 = model_config.is_lumina2
self.arch: ModelArch = model_config.arch
# self.is_xl = model_config.is_xl
# self.is_v2 = model_config.is_v2
# self.is_ssd = model_config.is_ssd
# self.is_v3 = model_config.is_v3
# self.is_vega = model_config.is_vega
# self.is_pixart = model_config.is_pixart
# self.is_auraflow = model_config.is_auraflow
# self.is_flux = model_config.is_flux
# self.is_flex2 = model_config.is_flex2
# self.is_lumina2 = model_config.is_lumina2
self.use_text_encoder_1 = model_config.use_text_encoder_1
self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -204,6 +205,53 @@ class StableDiffusion:
self.invert_assistant_lora = False
self._after_sample_img_hooks = []
self._status_update_hooks = []
# todo update this based on the model
self.is_transformer = False
# properties for old arch for backwards compatibility
@property
def is_xl(self):
return self.arch == 'sdxl'
@property
def is_v2(self):
return self.arch == 'sd2'
@property
def is_ssd(self):
return self.arch == 'ssd'
@property
def is_v3(self):
return self.arch == 'sd3'
@property
def is_vega(self):
return self.arch == 'vega'
@property
def is_pixart(self):
return self.arch == 'pixart'
@property
def is_auraflow(self):
return self.arch == 'auraflow'
@property
def is_flux(self):
return self.arch == 'flux'
@property
def is_flex2(self):
return self.arch == 'flex2'
@property
def is_lumina2(self):
return self.arch == 'lumina2'
@property
def unet_unwrapped(self):
return unwrap_model(self.unet)
def load_model(self):
if self.is_loaded:
@@ -935,7 +983,6 @@ class StableDiffusion:
if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2:
# pixart and sd3 don't use a unet
self.unet = pipe.transformer
self.unet_unwrapped = pipe.transformer
else:
self.unet: 'UNet2DConditionModel' = pipe.unet
self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
@@ -1734,7 +1781,8 @@ class StableDiffusion:
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor
timesteps: torch.IntTensor,
**kwargs,
) -> torch.FloatTensor:
original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)

toolkit/util/get_model.py (new file, 12 lines)

@@ -0,0 +1,12 @@
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import ModelConfig
def get_model_class(config: ModelConfig):
if config.arch == "wan21":
from toolkit.models.wan21 import Wan21
return Wan21
elif config.arch == "cogview4":
from toolkit.models.cogview4 import CogView4
return CogView4
else:
return StableDiffusion
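Usage sketch (config values are illustrative):
from toolkit.config_modules import ModelConfig
from toolkit.util.get_model import get_model_class

config = ModelConfig(arch="cogview4", name_or_path="THUDM/CogView4-6B")
ModelClass = get_model_class(config)   # CogView4; any other arch falls back to StableDiffusion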

toolkit/util/quantize.py (new file, 55 lines)

@@ -0,0 +1,55 @@
from fnmatch import fnmatch
from typing import Any, Dict, List, Optional, Union
import torch
from optimum.quanto.quantize import _quantize_submodule
from optimum.quanto.tensor import Optimizer, qtype
# the quantize function in quanto had a bug where it was using exclude instead of include
def quantize(
model: torch.nn.Module,
weights: Optional[Union[str, qtype]] = None,
activations: Optional[Union[str, qtype]] = None,
optimizer: Optional[Optimizer] = None,
include: Optional[Union[str, List[str]]] = None,
exclude: Optional[Union[str, List[str]]] = None,
):
"""Quantize the specified model submodules
Recursively quantize the submodules of the specified parent model.
Only modules that have quantized counterparts will be quantized.
If include patterns are specified, the submodule name must match one of them.
If exclude patterns are specified, the submodule must not match any of them.
Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See
https://docs.python.org/3/library/fnmatch.html for more details.
Note: quantization happens in-place and modifies the original model and its descendants.
Args:
model (`torch.nn.Module`): the model whose submodules will be quantized.
weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.
activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.
include (`Optional[Union[str, List[str]]]`):
Patterns constituting the allowlist. If provided, module names must match at
least one pattern from the allowlist.
exclude (`Optional[Union[str, List[str]]]`):
Patterns constituting the denylist. If provided, module names must not match
any patterns from the denylist.
"""
if include is not None:
include = [include] if isinstance(include, str) else include
if exclude is not None:
exclude = [exclude] if isinstance(exclude, str) else exclude
for name, m in model.named_modules():
if include is not None and not any(fnmatch(name, pattern) for pattern in include):
continue
if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
continue
_quantize_submodule(model, name, m, weights=weights,
activations=activations, optimizer=optimizer)
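Usage sketch mirroring the CogView4 loader above (the transformer variable is a placeholder for any torch.nn.Module):
from optimum.quanto import freeze, qfloat8
from toolkit.util.quantize import quantize

# quantize only the transformer blocks, keeping their norm layers in full precision
quantize(
    transformer,
    weights=qfloat8,
    include=["transformer_blocks.*"],
    exclude=["transformer_blocks.*.norm1", "transformer_blocks.*.norm2"],
)
freeze(transformer)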