Merge pull request #264 from ostris/cogview4

Added basics for CogView4. Broken as hell though. Dont use.
This commit is contained in:
Jaret Burkett
2025-03-05 14:52:06 -07:00
committed by GitHub
14 changed files with 2287 additions and 82 deletions

4
.gitmodules vendored
View File

@@ -1,12 +1,16 @@
[submodule "repositories/sd-scripts"] [submodule "repositories/sd-scripts"]
path = repositories/sd-scripts path = repositories/sd-scripts
url = https://github.com/kohya-ss/sd-scripts.git url = https://github.com/kohya-ss/sd-scripts.git
commit = b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c
[submodule "repositories/leco"] [submodule "repositories/leco"]
path = repositories/leco path = repositories/leco
url = https://github.com/p1atdev/LECO url = https://github.com/p1atdev/LECO
commit = 9294adf40218e917df4516737afb13f069a6789d
[submodule "repositories/batch_annotator"] [submodule "repositories/batch_annotator"]
path = repositories/batch_annotator path = repositories/batch_annotator
url = https://github.com/ostris/batch-annotator url = https://github.com/ostris/batch-annotator
commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf
[submodule "repositories/ipadapter"] [submodule "repositories/ipadapter"]
path = repositories/ipadapter path = repositories/ipadapter
url = https://github.com/tencent-ailab/IP-Adapter.git url = https://github.com/tencent-ailab/IP-Adapter.git
commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14

View File

@@ -380,9 +380,19 @@ class SDTrainer(BaseSDTrainProcess):
elif self.sd.prediction_type == 'v_prediction': elif self.sd.prediction_type == 'v_prediction':
# v-parameterization training # v-parameterization training
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps) target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
elif hasattr(self.sd, 'get_loss_target'):
target = self.sd.get_loss_target(
noise=noise,
batch=batch,
timesteps=timesteps,
).detach()
elif self.sd.is_flow_matching: elif self.sd.is_flow_matching:
# forward ODE
target = (noise - batch.latents).detach() target = (noise - batch.latents).detach()
# reverse ODE
# target = (batch.latents - noise).detach()
else: else:
target = noise target = noise

View File

@@ -68,6 +68,8 @@ import transformers
import diffusers import diffusers
import hashlib import hashlib
from toolkit.util.get_model import get_model_class
def flush(): def flush():
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
@@ -666,7 +668,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
# # prepare all the models stuff for accelerator (hopefully we dont miss any) # # prepare all the models stuff for accelerator (hopefully we dont miss any)
self.sd.vae = self.accelerator.prepare(self.sd.vae) self.sd.vae = self.accelerator.prepare(self.sd.vae)
if self.sd.unet is not None: if self.sd.unet is not None:
self.sd.unet_unwrapped = self.sd.unet
self.sd.unet = self.accelerator.prepare(self.sd.unet) self.sd.unet = self.accelerator.prepare(self.sd.unet)
# todo always tdo it? # todo always tdo it?
self.modules_being_trained.append(self.sd.unet) self.modules_being_trained.append(self.sd.unet)
@@ -1103,11 +1104,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
if timestep_type is None: if timestep_type is None:
timestep_type = self.train_config.timestep_type timestep_type = self.train_config.timestep_type
patch_size = 1
if self.sd.is_flux:
# flux is a patch size of 1, but latents are divided by 2, so we need to double it
patch_size = 2
elif hasattr(self.sd.unet.config, 'patch_size'):
patch_size = self.sd.unet.config.patch_size
self.sd.noise_scheduler.set_train_timesteps( self.sd.noise_scheduler.set_train_timesteps(
num_train_timesteps, num_train_timesteps,
device=self.device_torch, device=self.device_torch,
timestep_type=timestep_type, timestep_type=timestep_type,
latents=latents latents=latents,
patch_size=patch_size,
) )
else: else:
self.sd.noise_scheduler.set_timesteps( self.sd.noise_scheduler.set_timesteps(
@@ -1401,21 +1410,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
model_config_to_load.name_or_path = latest_save_path model_config_to_load.name_or_path = latest_save_path
self.load_training_state_from_metadata(latest_save_path) self.load_training_state_from_metadata(latest_save_path)
# get the noise scheduler ModelClass = get_model_class(self.model_config)
arch = 'sd' # if the model class has get_train_scheduler static method
if self.model_config.is_pixart: if hasattr(ModelClass, 'get_train_scheduler'):
arch = 'pixart' sampler = ModelClass.get_train_scheduler()
if self.model_config.is_flux: else:
arch = 'flux' # get the noise scheduler
if self.model_config.is_lumina2: arch = 'sd'
arch = 'lumina2' if self.model_config.is_pixart:
sampler = get_sampler( arch = 'pixart'
self.train_config.noise_scheduler, if self.model_config.is_flux:
{ arch = 'flux'
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon", if self.model_config.is_lumina2:
}, arch = 'lumina2'
arch=arch, sampler = get_sampler(
) self.train_config.noise_scheduler,
{
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
},
arch=arch,
)
if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None: if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner') previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner')
@@ -1423,7 +1437,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
model_config_to_load.refiner_name_or_path = previous_refiner_save model_config_to_load.refiner_name_or_path = previous_refiner_save
self.load_training_state_from_metadata(previous_refiner_save) self.load_training_state_from_metadata(previous_refiner_save)
self.sd = StableDiffusion( self.sd = ModelClass(
device=self.device, device=self.device,
model_config=model_config_to_load, model_config=model_config_to_load,
dtype=self.train_config.dtype, dtype=self.train_config.dtype,
@@ -1559,6 +1573,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
# if is_lycoris: # if is_lycoris:
# preset = PRESET['full'] # preset = PRESET['full']
# NetworkClass.apply_preset(preset) # NetworkClass.apply_preset(preset)
if hasattr(self.sd, 'target_lora_modules'):
network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
self.network = NetworkClass( self.network = NetworkClass(
text_encoder=text_encoder, text_encoder=text_encoder,
@@ -1587,6 +1604,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
network_config=self.network_config, network_config=self.network_config,
network_type=self.network_config.type, network_type=self.network_config.type,
transformer_only=self.network_config.transformer_only, transformer_only=self.network_config.transformer_only,
is_transformer=self.sd.is_transformer,
**network_kwargs **network_kwargs
) )

View File

@@ -1,8 +1,8 @@
torch==2.5.1 torch==2.5.1
torchvision==0.20.1 torchvision==0.20.1
safetensors safetensors
git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec git+https://github.com/huggingface/diffusers@24c062aaa19f5626d03d058daf8afffa2dfd49f7
transformers transformers==4.49.0
lycoris-lora==1.8.3 lycoris-lora==1.8.3
flatten_json flatten_json
pyyaml pyyaml

View File

@@ -29,7 +29,7 @@ def paramiter_count(model):
return int(paramiter_count) return int(paramiter_count)
def calculate_metrics(vae, images, max_imgs=-1): def calculate_metrics(vae, images, max_imgs=-1, save_output=False):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = vae.to(device) vae = vae.to(device)
lpips_model = lpips.LPIPS(net='alex').to(device) lpips_model = lpips.LPIPS(net='alex').to(device)
@@ -44,6 +44,9 @@ def calculate_metrics(vae, images, max_imgs=-1):
# ]) # ])
# needs values between -1 and 1 # needs values between -1 and 1
to_tensor = ToTensor() to_tensor = ToTensor()
# remove _reconstructed.png files
images = [img for img in images if not img.endswith("_reconstructed.png")]
if max_imgs > 0 and len(images) > max_imgs: if max_imgs > 0 and len(images) > max_imgs:
images = images[:max_imgs] images = images[:max_imgs]
@@ -82,6 +85,15 @@ def calculate_metrics(vae, images, max_imgs=-1):
avg_rfid = 0 avg_rfid = 0
avg_psnr = sum(psnr_scores) / len(psnr_scores) avg_psnr = sum(psnr_scores) / len(psnr_scores)
avg_lpips = sum(lpips_scores) / len(lpips_scores) avg_lpips = sum(lpips_scores) / len(lpips_scores)
if save_output:
filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
folder = os.path.dirname(img_path)
save_path = os.path.join(folder, filename_no_ext + "_reconstructed.png")
reconstructed = (reconstructed + 1) / 2
reconstructed = reconstructed.clamp(0, 1)
reconstructed = transforms.ToPILImage()(reconstructed[0].cpu())
reconstructed.save(save_path)
return avg_rfid, avg_psnr, avg_lpips return avg_rfid, avg_psnr, avg_lpips
@@ -91,18 +103,23 @@ def main():
parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model") parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model")
parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images") parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.") parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.")
# boolean store true
parser.add_argument("--save_output", action="store_true", help="Save the output images")
args = parser.parse_args() args = parser.parse_args()
if os.path.isfile(args.vae_path): if os.path.isfile(args.vae_path):
vae = AutoencoderKL.from_single_file(args.vae_path) vae = AutoencoderKL.from_single_file(args.vae_path)
else: else:
vae = AutoencoderKL.from_pretrained(args.vae_path) try:
vae = AutoencoderKL.from_pretrained(args.vae_path)
except:
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
vae.eval() vae.eval()
vae = vae.to(device) vae = vae.to(device)
print(f"Model has {paramiter_count(vae)} parameters") print(f"Model has {paramiter_count(vae)} parameters")
images = load_images(args.image_folder) images = load_images(args.image_folder)
avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs) avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs, args.save_output)
# print(f"Average rFID: {avg_rfid}") # print(f"Average rFID: {avg_rfid}")
print(f"Average PSNR: {avg_psnr}") print(f"Average PSNR: {avg_psnr}")

View File

@@ -432,6 +432,9 @@ class TrainConfig:
self.force_consistent_noise = kwargs.get('force_consistent_noise', False) self.force_consistent_noise = kwargs.get('force_consistent_noise', False)
ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']
class ModelConfig: class ModelConfig:
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.name_or_path: str = kwargs.get('name_or_path', None) self.name_or_path: str = kwargs.get('name_or_path', None)
@@ -509,6 +512,36 @@ class ModelConfig:
self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3) self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3)
self.te_name_or_path = kwargs.get("te_name_or_path", None) self.te_name_or_path = kwargs.get("te_name_or_path", None)
self.arch: ModelArch = kwargs.get("arch", None)
# handle migrating to new model arch
if self.arch is None:
if kwargs.get('is_v2', False):
self.arch = 'sd2'
elif kwargs.get('is_v3', False):
self.arch = 'sd3'
elif kwargs.get('is_xl', False):
self.arch = 'sdxl'
elif kwargs.get('is_pixart', False):
self.arch = 'pixart'
elif kwargs.get('is_pixart_sigma', False):
self.arch = 'pixart_sigma'
elif kwargs.get('is_auraflow', False):
self.arch = 'auraflow'
elif kwargs.get('is_flux', False):
self.arch = 'flux'
elif kwargs.get('is_flex2', False):
self.arch = 'flex2'
elif kwargs.get('is_lumina2', False):
self.arch = 'lumina2'
elif kwargs.get('is_vega', False):
self.arch = 'vega'
elif kwargs.get('is_ssd', False):
self.arch = 'ssd'
else:
self.arch = 'sd1'
class EMAConfig: class EMAConfig:

View File

@@ -178,6 +178,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
transformer_only: bool = False, transformer_only: bool = False,
peft_format: bool = False, peft_format: bool = False,
is_assistant_adapter: bool = False, is_assistant_adapter: bool = False,
is_transformer: bool = False,
**kwargs **kwargs
) -> None: ) -> None:
""" """
@@ -237,9 +238,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.network_config: NetworkConfig = kwargs.get("network_config", None) self.network_config: NetworkConfig = kwargs.get("network_config", None)
self.peft_format = peft_format self.peft_format = peft_format
self.is_transformer = is_transformer
# always do peft for flux only for now # always do peft for flux only for now
if self.is_flux or self.is_v3 or self.is_lumina2: if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer:
# don't do peft format for lokr # don't do peft format for lokr
if self.network_type.lower() != "lokr": if self.network_type.lower() != "lokr":
self.peft_format = True self.peft_format = True
@@ -282,7 +285,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
unet_prefix = self.LORA_PREFIX_UNET unet_prefix = self.LORA_PREFIX_UNET
if self.peft_format: if self.peft_format:
unet_prefix = self.PEFT_PREFIX_UNET unet_prefix = self.PEFT_PREFIX_UNET
if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2: if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2 or self.is_transformer:
unet_prefix = f"lora_transformer" unet_prefix = f"lora_transformer"
if self.peft_format: if self.peft_format:
unet_prefix = "transformer" unet_prefix = "transformer"
@@ -341,6 +344,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if self.transformer_only and self.is_v3 and is_unet: if self.transformer_only and self.is_v3 and is_unet:
if "transformer_blocks" not in lora_name: if "transformer_blocks" not in lora_name:
skip = True skip = True
# handle custom models
if self.transformer_only and is_unet and hasattr(root_module, 'transformer_blocks'):
if "transformer_blocks" not in lora_name:
skip = True
if (is_linear or is_conv2d) and not skip: if (is_linear or is_conv2d) and not skip:

1426
toolkit/models/base_model.py Normal file

File diff suppressed because it is too large Load Diff

466
toolkit/models/cogview4.py Normal file
View File

@@ -0,0 +1,466 @@
# DONT USE THIS!. IT DOES NOT WORK YET!
# Will revisit this when they release more info on how it was trained.
import weakref
from diffusers import CogView4Pipeline
import torch
import yaml
from toolkit.basic import flush
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
import os
import copy
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
import torch
import diffusers
from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline
from optimum.quanto import freeze, qfloat8, QTensor, qint4
from toolkit.util.quantize import quantize
from transformers import GlmModel, AutoTokenizer
from diffusers import FlowMatchEulerDiscreteScheduler
from typing import TYPE_CHECKING
from toolkit.accelerator import unwrap_model
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
if TYPE_CHECKING:
from toolkit.lora_special import LoRASpecialNetwork
# remove this after a bug is fixed in diffusers code. This is a workaround.
class FakeModel:
def __init__(self, model):
self.model_ref = weakref.ref(model)
pass
@property
def device(self):
return self.model_ref().device
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": 0.25,
"invert_sigmas": False,
"max_image_seq_len": 4096,
"max_shift": 0.75,
"num_train_timesteps": 1000,
"shift": 1.0,
"shift_terminal": None,
"time_shift_type": "linear",
"use_beta_sigmas": False,
"use_dynamic_shifting": True,
"use_exponential_sigmas": False,
"use_karras_sigmas": False
}
class CogView4(BaseModel):
def __init__(
self,
device,
model_config: ModelConfig,
dtype='bf16',
custom_pipeline=None,
noise_scheduler=None,
**kwargs
):
super().__init__(device, model_config, dtype,
custom_pipeline, noise_scheduler, **kwargs)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ['CogView4Transformer2DModel']
# cache for holding noise
self.effective_noise = None
# static method to get the scheduler
@staticmethod
def get_train_scheduler():
scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
return scheduler
def load_model(self):
dtype = self.torch_dtype
base_model_path = "THUDM/CogView4-6B"
model_path = self.model_config.name_or_path
self.print_and_status_update("Loading CogView4 model")
# base_model_path = "black-forest-labs/FLUX.1-schnell"
base_model_path = self.model_config.name_or_path_original
subfolder = 'transformer'
transformer_path = model_path
if os.path.exists(transformer_path):
subfolder = None
transformer_path = os.path.join(transformer_path, 'transformer')
# check if the path is a full checkpoint.
te_folder_path = os.path.join(model_path, 'text_encoder')
# if we have the te, this folder is a full checkpoint, use it as the base
if os.path.exists(te_folder_path):
base_model_path = model_path
self.print_and_status_update("Loading GlmModel")
tokenizer = AutoTokenizer.from_pretrained(
base_model_path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder = GlmModel.from_pretrained(
base_model_path, subfolder="text_encoder", torch_dtype=dtype)
text_encoder.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing GlmModel")
quantize(text_encoder, weights=qfloat8)
freeze(text_encoder)
flush()
# hack to fix diffusers bug workaround
text_encoder.model = FakeModel(text_encoder)
self.print_and_status_update("Loading transformer")
transformer = CogView4Transformer2DModel.from_pretrained(
transformer_path,
subfolder=subfolder,
torch_dtype=dtype,
)
if self.model_config.split_model_over_gpus:
raise ValueError(
"Splitting model over gpus is not supported for CogViewModels models")
transformer.to(self.quantize_device, dtype=dtype)
flush()
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
raise ValueError(
"Assistant LoRA is not supported for CogViewModels models currently")
if self.model_config.lora_path is not None:
raise ValueError(
"Loading LoRA is not supported for CogViewModels models currently")
flush()
if self.model_config.quantize:
quantization_args = self.model_config.quantize_kwargs
if 'exclude' not in quantization_args:
quantization_args['exclude'] = []
if 'include' not in quantization_args:
quantization_args['include'] = []
# Be more specific with the include pattern to exactly match transformer blocks
quantization_args['include'] += ["transformer_blocks.*"]
# Exclude all LayerNorm layers within transformer blocks
quantization_args['exclude'] += [
"transformer_blocks.*.norm1",
"transformer_blocks.*.norm2",
"transformer_blocks.*.norm2_context",
"transformer_blocks.*.attn1.norm_q",
"transformer_blocks.*.attn1.norm_k"
]
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = qfloat8
self.print_and_status_update("Quantizing transformer")
quantize(transformer, weights=quantization_type, **quantization_args)
freeze(transformer)
transformer.to(self.device_torch)
else:
transformer.to(self.device_torch, dtype=dtype)
flush()
scheduler = CogView4.get_train_scheduler()
self.print_and_status_update("Loading VAE")
vae = AutoencoderKL.from_pretrained(
base_model_path, subfolder="vae", torch_dtype=dtype)
flush()
self.print_and_status_update("Making pipe")
pipe: CogView4Pipeline = CogView4Pipeline(
scheduler=scheduler,
text_encoder=None,
tokenizer=tokenizer,
vae=vae,
transformer=None,
)
pipe.text_encoder = text_encoder
pipe.transformer = transformer
self.print_and_status_update("Preparing Model")
text_encoder = pipe.text_encoder
tokenizer = pipe.tokenizer
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
text_encoder.to(self.device_torch)
text_encoder.requires_grad_(False)
text_encoder.eval()
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
self.pipeline = pipe
self.model = transformer
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
def get_generation_pipeline(self):
scheduler = CogView4.get_train_scheduler()
pipeline = CogView4Pipeline(
vae=self.vae,
transformer=self.unet,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=scheduler,
)
return pipeline
def generate_single_image(
self,
pipeline: CogView4Pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype),
negative_prompt_embeds=unconditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype),
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
generator=generator,
**extra
).images[0]
return img
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs
):
# target_size = (height, width)
target_size = latent_model_input.shape[-2:]
# multiply by 8
target_size = (target_size[0] * 8, target_size[1] * 8)
crops_coords_top_left = torch.tensor(
[(0, 0)], dtype=self.torch_dtype, device=self.device_torch)
original_size = torch.tensor(
[target_size], dtype=self.torch_dtype, device=self.device_torch)
target_size = original_size.clone()
noise_pred_cond = self.model(
hidden_states=latent_model_input,
encoder_hidden_states=text_embeddings.text_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
return_dict=False,
)[0]
return noise_pred_cond
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
prompt_embeds, _ = self.pipeline.encode_prompt(
prompt,
do_classifier_free_guidance=False,
device=self.device_torch,
dtype=self.torch_dtype,
)
return PromptEmbeds(prompt_embeds)
def get_model_has_grad(self):
return self.model.proj_out.weight.requires_grad
def get_te_has_grad(self):
return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad
def save_model(self, output_path, meta, save_dtype):
# only save the unet
transformer: CogView4Transformer2DModel = unwrap_model(self.model)
transformer.save_pretrained(
save_directory=os.path.join(output_path, 'transformer'),
safe_serialization=True,
)
meta_path = os.path.join(output_path, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
yaml.dump(meta, f)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get('noise')
effective_noise = self.effective_noise
batch = kwargs.get('batch')
if batch is None:
raise ValueError("Batch is not provided")
if noise is None:
raise ValueError("Noise is not provided")
# return batch.latents
# return (batch.latents - noise).detach()
return (noise - batch.latents).detach()
# return (batch.latents).detach()
# return (effective_noise - batch.latents).detach()
def _get_low_res_latents(self, latents):
# todo prevent needing to do this and grab the tensor another way.
with torch.no_grad():
# Decode latents to image space
images = self.decode_latents(
latents, device=latents.device, dtype=latents.dtype)
# Downsample by a factor of 2 using bilinear interpolation
B, C, H, W = images.shape
low_res_images = torch.nn.functional.interpolate(
images,
size=(H // 2, W // 2),
mode="bilinear",
align_corners=False
)
# Upsample back to original resolution to match expected VAE input dimensions
upsampled_low_res_images = torch.nn.functional.interpolate(
low_res_images,
size=(H, W),
mode="bilinear",
align_corners=False
)
# Encode the low-resolution images back to latent space
low_res_latents = self.encode_images(
upsampled_low_res_images, device=latents.device, dtype=latents.dtype)
return low_res_latents
# def add_noise(
# self,
# original_samples: torch.FloatTensor,
# noise: torch.FloatTensor,
# timesteps: torch.IntTensor,
# **kwargs,
# ) -> torch.FloatTensor:
# relay_start_point = 500
# # Store original samples for loss calculation
# self.original_samples = original_samples
# # Prepare chunks for batch processing
# original_samples_chunks = torch.chunk(
# original_samples, original_samples.shape[0], dim=0)
# noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
# timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
# # Get the low res latents only if needed
# low_res_latents_chunks = None
# # Handle case where timesteps is a single value for all samples
# if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
# timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
# noisy_latents_chunks = []
# effective_noise_chunks = [] # Store the effective noise for each sample
# for idx in range(original_samples.shape[0]):
# t = timesteps_chunks[idx]
# t_01 = (t / 1000).to(original_samples_chunks[idx].device)
# # Flowmatching interpolation between original and noise
# if t > relay_start_point:
# # Standard flowmatching - direct linear interpolation
# noisy_latents = (1 - t_01) * original_samples_chunks[idx] + t_01 * noise_chunks[idx]
# effective_noise_chunks.append(noise_chunks[idx]) # Effective noise is just the noise
# else:
# # Relay flowmatching case - only compute low_res_latents if needed
# if low_res_latents_chunks is None:
# low_res_latents = self._get_low_res_latents(original_samples)
# low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)
# # Calculate the relay ratio (0 to 1)
# t_ratio = t.float() / relay_start_point
# t_ratio = torch.clamp(t_ratio, 0.0, 1.0)
# # First blend between original and low-res based on t_ratio
# z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx]
# added_lor_res_noise = z0_t - original_samples_chunks[idx]
# # Then apply flowmatching interpolation between this blended state and noise
# noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx]
# # For prediction target, we need to store the effective "source"
# effective_noise_chunks.append(noise_chunks[idx] + added_lor_res_noise)
# noisy_latents_chunks.append(noisy_latents)
# noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
# self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation
# return noisy_latents
# def add_noise(
# self,
# original_samples: torch.FloatTensor,
# noise: torch.FloatTensor,
# timesteps: torch.IntTensor,
# **kwargs,
# ) -> torch.FloatTensor:
# relay_start_point = 500
# # Store original samples for loss calculation
# self.original_samples = original_samples
# # Prepare chunks for batch processing
# original_samples_chunks = torch.chunk(
# original_samples, original_samples.shape[0], dim=0)
# noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
# timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
# # Get the low res latents only if needed
# low_res_latents = self._get_low_res_latents(original_samples)
# low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)
# # Handle case where timesteps is a single value for all samples
# if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
# timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
# noisy_latents_chunks = []
# effective_noise_chunks = [] # Store the effective noise for each sample
# for idx in range(original_samples.shape[0]):
# t = timesteps_chunks[idx]
# t_01 = (t / 1000).to(original_samples_chunks[idx].device)
# lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx]
# # lrln = lrln * (1 - t_01)
# # make the noise an interpolation between noise and low_res_latents with
# # being noise at t_01=1 and low_res_latents at t_01=0
# new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln
# # new_noise = noise_chunks[idx] + lrln
# # new_noise = noise_chunks[idx] + lrln
# # Then apply flowmatching interpolation between this blended state and noise
# noisy_latents = (1 - t_01) * original_samples + t_01 * new_noise
# # For prediction target, we need to store the effective "source"
# effective_noise_chunks.append(new_noise)
# noisy_latents_chunks.append(noisy_latents)
# noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
# self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation
# return noisy_latents

82
toolkit/models/wan21.py Normal file
View File

@@ -0,0 +1,82 @@
# WIP, coming soon ish
import torch
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
from toolkit.paths import REPOS_ROOT
import sys
import os
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
class Wan21(BaseModel):
def __init__(
self,
device,
model_config: ModelConfig,
dtype='bf16',
custom_pipeline=None,
noise_scheduler=None,
**kwargs
):
super().__init__(device, model_config, dtype,
custom_pipeline, noise_scheduler, **kwargs)
self.is_flow_matching = True
raise NotImplementedError("Wan21 is not implemented yet")
# these must be implemented in child classes
def load_model(self):
pass
def get_generation_pipeline(self):
# override this in child classes
raise NotImplementedError(
"get_generation_pipeline must be implemented in child classes")
def generate_single_image(
self,
pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
# override this in child classes
raise NotImplementedError(
"generate_single_image must be implemented in child classes")
def get_noise_prediction(
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs
):
raise NotImplementedError(
"get_noise_prediction must be implemented in child classes")
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
raise NotImplementedError(
"get_prompt_embeds must be implemented in child classes")
def get_model_has_grad(self):
raise NotImplementedError(
"get_model_has_grad must be implemented in child classes")
def get_te_has_grad(self):
raise NotImplementedError(
"get_te_has_grad must be implemented in child classes")

View File

@@ -44,7 +44,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum()) hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
# flatten second half to max # flatten second half to max
hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max() hbsmntw_weighing[num_timesteps //
2:] = hbsmntw_weighing[num_timesteps // 2:].max()
# Create linear timesteps from 1000 to 0 # Create linear timesteps from 1000 to 0
timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
@@ -56,7 +57,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor: def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
# Get the indices of the timesteps # Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] step_indices = [(self.timesteps == t).nonzero().item()
for t in timesteps]
# Get the weights for the timesteps # Get the weights for the timesteps
if v2: if v2:
@@ -70,7 +72,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
sigmas = self.sigmas.to(device=device, dtype=dtype) sigmas = self.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.timesteps.to(device) schedule_timesteps = self.timesteps.to(device)
timesteps = timesteps.to(device) timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] step_indices = [(schedule_timesteps == t).nonzero().item()
for t in timesteps]
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim: while len(sigma.shape) < n_dim:
@@ -84,27 +87,24 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
noise: torch.Tensor, noise: torch.Tensor,
timesteps: torch.Tensor, timesteps: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
## Add noise according to flow matching.
## zt = (1 - texp) * x + texp * z1
# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
# timestep needs to be in [0, 1], we store them in [0, 1000]
# noisy_sample = (1 - timestep) * latent + timestep * noise
t_01 = (timesteps / 1000).to(original_samples.device) t_01 = (timesteps / 1000).to(original_samples.device)
noisy_model_input = (1 - t_01) * original_samples + t_01 * noise # forward ODE
noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
# n_dim = original_samples.ndim # reverse ODE
# sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) # noisy_model_input = (1 - t_01) * noise + t_01 * original_samples
# noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
return noisy_model_input return noisy_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
return sample return sample
def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None): def set_train_timesteps(
self,
num_timesteps,
device,
timestep_type='linear',
latents=None,
patch_size=1
):
self.timestep_type = timestep_type self.timestep_type = timestep_type
if timestep_type == 'linear': if timestep_type == 'linear':
timesteps = torch.linspace(1000, 0, num_timesteps, device=device) timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
@@ -124,42 +124,67 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
self.timesteps = timesteps.to(device=device) self.timesteps = timesteps.to(device=device)
return timesteps return timesteps
elif timestep_type == 'flux_shift' or timestep_type == 'lumina2_shift': elif timestep_type in ['flux_shift', 'lumina2_shift', 'shift']:
# matches inference dynamic shifting # matches inference dynamic shifting
timesteps = np.linspace( timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps self._sigma_to_t(self.sigma_max), self._sigma_to_t(
self.sigma_min), num_timesteps
) )
sigmas = timesteps / self.config.num_train_timesteps sigmas = timesteps / self.config.num_train_timesteps
if latents is None:
raise ValueError('latents is None')
h = latents.shape[2] // 2 # Divide by ph
w = latents.shape[3] // 2 # Divide by pw
image_seq_len = h * w
# todo need to know the mu for the shift if self.config.use_dynamic_shifting:
mu = calculate_shift( if latents is None:
image_seq_len, raise ValueError('latents is None')
self.config.get("base_image_seq_len", 256),
self.config.get("max_image_seq_len", 4096),
self.config.get("base_shift", 0.5),
self.config.get("max_shift", 1.16),
)
sigmas = self.time_shift(mu, 1.0, sigmas)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) # for flux we double up the patch size before sending her to simulate the latent reduction
h = latents.shape[2]
w = latents.shape[3]
image_seq_len = h * w // (patch_size**2)
mu = calculate_shift(
image_seq_len,
self.config.get("base_image_seq_len", 256),
self.config.get("max_image_seq_len", 4096),
self.config.get("base_shift", 0.5),
self.config.get("max_shift", 1.16),
)
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas)
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
sigmas = torch.from_numpy(sigmas).to(
dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps timesteps = sigmas * self.config.num_train_timesteps
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) if self.config.invert_sigmas:
sigmas = 1.0 - sigmas
timesteps = sigmas * self.config.num_train_timesteps
sigmas = torch.cat(
[sigmas, torch.ones(1, device=sigmas.device)])
else:
sigmas = torch.cat(
[sigmas, torch.zeros(1, device=sigmas.device)])
self.timesteps = timesteps.to(device=device) self.timesteps = timesteps.to(device=device)
self.sigmas = sigmas self.sigmas = sigmas
self.timesteps = timesteps.to(device=device) self.timesteps = timesteps.to(device=device)
return timesteps return timesteps
elif timestep_type == 'lognorm_blend': elif timestep_type == 'lognorm_blend':
# disgtribute timestepd to the center/early and blend in linear # disgtribute timestepd to the center/early and blend in linear
alpha = 0.75 alpha = 0.75
@@ -173,7 +198,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
t1 = ((1 - t1/t1.max()) * 1000) t1 = ((1 - t1/t1.max()) * 1000)
# add half of linear # add half of linear
t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device) t2 = torch.linspace(1000, 0, int(
num_timesteps * (1 - alpha)), device=device)
timesteps = torch.cat((t1, t2)) timesteps = torch.cat((t1, t2))
# Sort the timesteps in descending order # Sort the timesteps in descending order

View File

@@ -29,7 +29,7 @@ from toolkit.ip_adapter import IPAdapter
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict, load_vae convert_vae_state_dict, load_vae
from toolkit import train_tools from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
from toolkit.metadata import get_meta_for_safetensors from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.decorator import Decorator from toolkit.models.decorator import Decorator
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
@@ -64,7 +64,8 @@ from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4 from optimum.quanto import freeze, qfloat8, QTensor, qint4
from toolkit.util.quantize import quantize
from toolkit.accelerator import get_accelerator, unwrap_model from toolkit.accelerator import get_accelerator, unwrap_model
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from toolkit.print import print_acc from toolkit.print import print_acc
@@ -160,7 +161,6 @@ class StableDiffusion:
self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline'] self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline']
self.vae: Union[None, 'AutoencoderKL'] self.vae: Union[None, 'AutoencoderKL']
self.unet: Union[None, 'UNet2DConditionModel'] self.unet: Union[None, 'UNet2DConditionModel']
self.unet_unwrapped: Union[None, 'UNet2DConditionModel']
self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
@@ -177,16 +177,17 @@ class StableDiffusion:
self.network = None self.network = None
self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
self.decorator: Union[Decorator, None] = None self.decorator: Union[Decorator, None] = None
self.is_xl = model_config.is_xl self.arch: ModelArch = model_config.arch
self.is_v2 = model_config.is_v2 # self.is_xl = model_config.is_xl
self.is_ssd = model_config.is_ssd # self.is_v2 = model_config.is_v2
self.is_v3 = model_config.is_v3 # self.is_ssd = model_config.is_ssd
self.is_vega = model_config.is_vega # self.is_v3 = model_config.is_v3
self.is_pixart = model_config.is_pixart # self.is_vega = model_config.is_vega
self.is_auraflow = model_config.is_auraflow # self.is_pixart = model_config.is_pixart
self.is_flux = model_config.is_flux # self.is_auraflow = model_config.is_auraflow
self.is_flex2 = model_config.is_flex2 # self.is_flux = model_config.is_flux
self.is_lumina2 = model_config.is_lumina2 # self.is_flex2 = model_config.is_flex2
# self.is_lumina2 = model_config.is_lumina2
self.use_text_encoder_1 = model_config.use_text_encoder_1 self.use_text_encoder_1 = model_config.use_text_encoder_1
self.use_text_encoder_2 = model_config.use_text_encoder_2 self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -204,6 +205,53 @@ class StableDiffusion:
self.invert_assistant_lora = False self.invert_assistant_lora = False
self._after_sample_img_hooks = [] self._after_sample_img_hooks = []
self._status_update_hooks = [] self._status_update_hooks = []
# todo update this based on the model
self.is_transformer = False
# properties for old arch for backwards compatibility
@property
def is_xl(self):
return self.arch == 'sdxl'
@property
def is_v2(self):
return self.arch == 'sd2'
@property
def is_ssd(self):
return self.arch == 'ssd'
@property
def is_v3(self):
return self.arch == 'sd3'
@property
def is_vega(self):
return self.arch == 'vega'
@property
def is_pixart(self):
return self.arch == 'pixart'
@property
def is_auraflow(self):
return self.arch == 'auraflow'
@property
def is_flux(self):
return self.arch == 'flux'
@property
def is_flex2(self):
return self.arch == 'flex2'
@property
def is_lumina2(self):
return self.arch == 'lumina2'
@property
def unet_unwrapped(self):
return unwrap_model(self.unet)
def load_model(self): def load_model(self):
if self.is_loaded: if self.is_loaded:
@@ -935,7 +983,6 @@ class StableDiffusion:
if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2: if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2:
# pixart and sd3 dont use a unet # pixart and sd3 dont use a unet
self.unet = pipe.transformer self.unet = pipe.transformer
self.unet_unwrapped = pipe.transformer
else: else:
self.unet: 'UNet2DConditionModel' = pipe.unet self.unet: 'UNet2DConditionModel' = pipe.unet
self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
@@ -1734,7 +1781,8 @@ class StableDiffusion:
self, self,
original_samples: torch.FloatTensor, original_samples: torch.FloatTensor,
noise: torch.FloatTensor, noise: torch.FloatTensor,
timesteps: torch.IntTensor timesteps: torch.IntTensor,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0) original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)

12
toolkit/util/get_model.py Normal file
View File

@@ -0,0 +1,12 @@
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import ModelConfig
def get_model_class(config: ModelConfig):
if config.arch == "wan21":
from toolkit.models.wan21 import Wan21
return Wan21
elif config.arch == "cogview4":
from toolkit.models.cogview4 import CogView4
return CogView4
else:
return StableDiffusion

55
toolkit/util/quantize.py Normal file
View File

@@ -0,0 +1,55 @@
from fnmatch import fnmatch
from typing import Any, Dict, List, Optional, Union
import torch
from optimum.quanto.quantize import _quantize_submodule
from optimum.quanto.tensor import Optimizer, qtype
# the quantize function in quanto had a bug where it was using exclude instead of include
def quantize(
model: torch.nn.Module,
weights: Optional[Union[str, qtype]] = None,
activations: Optional[Union[str, qtype]] = None,
optimizer: Optional[Optimizer] = None,
include: Optional[Union[str, List[str]]] = None,
exclude: Optional[Union[str, List[str]]] = None,
):
"""Quantize the specified model submodules
Recursively quantize the submodules of the specified parent model.
Only modules that have quantized counterparts will be quantized.
If include patterns are specified, the submodule name must match one of them.
If exclude patterns are specified, the submodule must not match one of them.
Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See
https://docs.python.org/3/library/fnmatch.html for more details.
Note: quantization happens in-place and modifies the original model and its descendants.
Args:
model (`torch.nn.Module`): the model whose submodules will be quantized.
weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.
activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.
include (`Optional[Union[str, List[str]]]`):
Patterns constituting the allowlist. If provided, module names must match at
least one pattern from the allowlist.
exclude (`Optional[Union[str, List[str]]]`):
Patterns constituting the denylist. If provided, module names must not match
any patterns from the denylist.
"""
if include is not None:
include = [include] if isinstance(include, str) else include
if exclude is not None:
exclude = [exclude] if isinstance(exclude, str) else exclude
for name, m in model.named_modules():
if include is not None and not any(fnmatch(name, pattern) for pattern in include):
continue
if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
continue
_quantize_submodule(model, name, m, weights=weights,
activations=activations, optimizer=optimizer)