Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added cogview4. Loss still needs work.
@@ -380,9 +380,19 @@ class SDTrainer(BaseSDTrainProcess):
             elif self.sd.prediction_type == 'v_prediction':
                 # v-parameterization training
                 target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
+            elif hasattr(self.sd, 'get_loss_target'):
+                target = self.sd.get_loss_target(
+                    noise=noise,
+                    batch=batch,
+                    timesteps=timesteps,
+                ).detach()
             elif self.sd.is_flow_matching:
+                # forward ODE
                 target = (noise - batch.latents).detach()
+                # reverse ODE
+                # target = (batch.latents - noise).detach()
             else:
                 target = noise
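As a rough, hedged sketch of how this branch is consumed downstream (the compute_loss wrapper, the pred tensor, and the plain MSE reduction are assumptions for illustration, not code from this commit):

    import torch.nn.functional as F

    def compute_loss(sd, pred, batch, noise, timesteps):
        # prefer a model-specific target when the model exposes the new hook
        if hasattr(sd, 'get_loss_target'):
            target = sd.get_loss_target(noise=noise, batch=batch, timesteps=timesteps).detach()
        elif sd.is_flow_matching:
            target = (noise - batch.latents).detach()  # forward ODE direction
        else:
            target = noise  # epsilon prediction
        return F.mse_loss(pred.float(), target.float())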
@@ -668,7 +668,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # # prepare all the models stuff for accelerator (hopefully we dont miss any)
         self.sd.vae = self.accelerator.prepare(self.sd.vae)
         if self.sd.unet is not None:
-            self.sd.unet_unwrapped = self.sd.unet
             self.sd.unet = self.accelerator.prepare(self.sd.unet)
             # todo always tdo it?
             self.modules_being_trained.append(self.sd.unet)
@@ -1105,11 +1104,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if timestep_type is None:
                 timestep_type = self.train_config.timestep_type

+            patch_size = 1
+            if self.sd.is_flux:
+                # flux is a patch size of 1, but latents are divided by 2, so we need to double it
+                patch_size = 2
+            elif hasattr(self.sd.unet.config, 'patch_size'):
+                patch_size = self.sd.unet.config.patch_size
+
             self.sd.noise_scheduler.set_train_timesteps(
                 num_train_timesteps,
                 device=self.device_torch,
                 timestep_type=timestep_type,
-                latents=latents
+                latents=latents,
+                patch_size=patch_size,
             )
         else:
             self.sd.noise_scheduler.set_timesteps(
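For context, patch_size only matters because the scheduler's dynamic shift is driven by the token count of the latent; a minimal sketch of that relationship (seq_len_from_latents is an illustrative helper, not part of the commit):

    def seq_len_from_latents(latents, patch_size: int) -> int:
        # latents: (B, C, H, W); each patch_size x patch_size latent block becomes one token
        h, w = latents.shape[2], latents.shape[3]
        return (h * w) // (patch_size ** 2)  # e.g. 128x128 latents with patch_size=2 -> 4096 tokens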
@@ -1403,21 +1410,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 model_config_to_load.name_or_path = latest_save_path
                 self.load_training_state_from_metadata(latest_save_path)

-        # get the noise scheduler
-        arch = 'sd'
-        if self.model_config.is_pixart:
-            arch = 'pixart'
-        if self.model_config.is_flux:
-            arch = 'flux'
-        if self.model_config.is_lumina2:
-            arch = 'lumina2'
-        sampler = get_sampler(
-            self.train_config.noise_scheduler,
-            {
-                "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
-            },
-            arch=arch,
-        )
+        ModelClass = get_model_class(self.model_config)
+        # if the model class has get_train_scheduler static method
+        if hasattr(ModelClass, 'get_train_scheduler'):
+            sampler = ModelClass.get_train_scheduler()
+        else:
+            # get the noise scheduler
+            arch = 'sd'
+            if self.model_config.is_pixart:
+                arch = 'pixart'
+            if self.model_config.is_flux:
+                arch = 'flux'
+            if self.model_config.is_lumina2:
+                arch = 'lumina2'
+            sampler = get_sampler(
+                self.train_config.noise_scheduler,
+                {
+                    "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
+                },
+                arch=arch,
+            )

         if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
             previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner')
@@ -1425,7 +1437,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 model_config_to_load.refiner_name_or_path = previous_refiner_save
                 self.load_training_state_from_metadata(previous_refiner_save)

-        ModelClass = get_model_class(self.model_config)
         self.sd = ModelClass(
             device=self.device,
             model_config=model_config_to_load,
@@ -1562,6 +1573,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # if is_lycoris:
             # preset = PRESET['full']
             # NetworkClass.apply_preset(preset)

+            if hasattr(self.sd, 'target_lora_modules'):
+                network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
+
             self.network = NetworkClass(
                 text_encoder=text_encoder,
@@ -1590,6 +1604,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 network_config=self.network_config,
                 network_type=self.network_config.type,
                 transformer_only=self.network_config.transformer_only,
+                is_transformer=self.sd.is_transformer,
                 **network_kwargs
             )
@@ -1,8 +1,8 @@
 torch==2.5.1
 torchvision==0.20.1
 safetensors
-git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec
-transformers
+git+https://github.com/huggingface/diffusers@97fda1b75c70705b245a462044fedb47abb17e56
+transformers==4.49.0
 lycoris-lora==1.8.3
 flatten_json
 pyyaml
@@ -29,7 +29,7 @@ def paramiter_count(model):
     return int(paramiter_count)


-def calculate_metrics(vae, images, max_imgs=-1):
+def calculate_metrics(vae, images, max_imgs=-1, save_output=False):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     vae = vae.to(device)
     lpips_model = lpips.LPIPS(net='alex').to(device)
@@ -44,6 +44,9 @@ def calculate_metrics(vae, images, max_imgs=-1):
     # ])
     # needs values between -1 and 1
     to_tensor = ToTensor()

+    # remove _reconstructed.png files
+    images = [img for img in images if not img.endswith("_reconstructed.png")]
+
     if max_imgs > 0 and len(images) > max_imgs:
         images = images[:max_imgs]
@@ -82,6 +85,15 @@ def calculate_metrics(vae, images, max_imgs=-1):
     avg_rfid = 0
     avg_psnr = sum(psnr_scores) / len(psnr_scores)
     avg_lpips = sum(lpips_scores) / len(lpips_scores)

+    if save_output:
+        filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
+        folder = os.path.dirname(img_path)
+        save_path = os.path.join(folder, filename_no_ext + "_reconstructed.png")
+        reconstructed = (reconstructed + 1) / 2
+        reconstructed = reconstructed.clamp(0, 1)
+        reconstructed = transforms.ToPILImage()(reconstructed[0].cpu())
+        reconstructed.save(save_path)
+
     return avg_rfid, avg_psnr, avg_lpips
@@ -91,18 +103,23 @@ def main():
     parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model")
     parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
     parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.")
+    # boolean store true
+    parser.add_argument("--save_output", action="store_true", help="Save the output images")
     args = parser.parse_args()

     if os.path.isfile(args.vae_path):
         vae = AutoencoderKL.from_single_file(args.vae_path)
     else:
-        vae = AutoencoderKL.from_pretrained(args.vae_path)
+        try:
+            vae = AutoencoderKL.from_pretrained(args.vae_path)
+        except:
+            vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
     vae.eval()
     vae = vae.to(device)
     print(f"Model has {paramiter_count(vae)} parameters")
     images = load_images(args.image_folder)

-    avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs)
+    avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs, args.save_output)

     # print(f"Average rFID: {avg_rfid}")
     print(f"Average PSNR: {avg_psnr}")
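A hedged usage example for the updated metrics script; the script filename is an assumption, the flags come from the argparse definitions above:

    # python vae_metrics.py --vae_path /path/to/vae --image_folder ./images --max_imgs 100 --save_output
    # reconstructions are written next to the inputs as <name>_reconstructed.png and are
    # skipped on the next run by the _reconstructed.png filter added above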
@@ -513,7 +513,7 @@ class ModelConfig:

         self.te_name_or_path = kwargs.get("te_name_or_path", None)

-        self.arch: ModelArch = kwargs.get("model_arch", None)
+        self.arch: ModelArch = kwargs.get("arch", None)

         # handle migrating to new model arch
         if self.arch is None:
@@ -178,6 +178,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             transformer_only: bool = False,
             peft_format: bool = False,
             is_assistant_adapter: bool = False,
+            is_transformer: bool = False,
             **kwargs
     ) -> None:
         """
@@ -237,9 +238,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.network_config: NetworkConfig = kwargs.get("network_config", None)

         self.peft_format = peft_format
+        self.is_transformer = is_transformer

         # always do peft for flux only for now
-        if self.is_flux or self.is_v3 or self.is_lumina2:
+        if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer:
             # don't do peft format for lokr
             if self.network_type.lower() != "lokr":
                 self.peft_format = True
@@ -282,7 +285,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         unet_prefix = self.LORA_PREFIX_UNET
         if self.peft_format:
             unet_prefix = self.PEFT_PREFIX_UNET
-        if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2:
+        if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2 or self.is_transformer:
             unet_prefix = f"lora_transformer"
             if self.peft_format:
                 unet_prefix = "transformer"
@@ -341,6 +344,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                 if self.transformer_only and self.is_v3 and is_unet:
                     if "transformer_blocks" not in lora_name:
                         skip = True

+                # handle custom models
+                if self.transformer_only and is_unet and hasattr(root_module, 'transformer_blocks'):
+                    if "transformer_blocks" not in lora_name:
+                        skip = True
+
                 if (is_linear or is_conv2d) and not skip:
@@ -168,11 +168,17 @@ class BaseModel:
         self.invert_assistant_lora = False
         self._after_sample_img_hooks = []
         self._status_update_hooks = []
+        self.is_transformer = False

     # properties for old arch for backwards compatibility
     @property
     def unet(self):
         return self.model

+    # set unet to model
+    @unet.setter
+    def unet(self, value):
+        self.model = value
+
     @property
     def unet_unwrapped(self):
@@ -235,6 +241,7 @@ class BaseModel:

     def generate_single_image(
         self,
+        pipeline,
         gen_config: GenerateImageConfig,
         conditional_embeds: PromptEmbeds,
         unconditional_embeds: PromptEmbeds,
@@ -257,6 +264,25 @@ class BaseModel:
     def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
         raise NotImplementedError(
             "get_prompt_embeds must be implemented in child classes")

+    def get_model_has_grad(self):
+        raise NotImplementedError(
+            "get_model_has_grad must be implemented in child classes")
+
+    def get_te_has_grad(self):
+        raise NotImplementedError(
+            "get_te_has_grad must be implemented in child classes")
+
+    def save_model(self, output_path, meta, save_dtype):
+        # todo handle dtype without overloading anything (vram, cpu, etc)
+        unwrap_model(self.pipeline).save_pretrained(
+            save_directory=output_path,
+            safe_serialization=True,
+        )
+        # save out meta config
+        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
+        with open(meta_path, 'w') as f:
+            yaml.dump(meta, f)
+    # end must be implemented in child classes

     def te_train(self):
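A minimal sketch of what a subclass is now expected to provide for these hooks (MyTransformerModel and its attribute layout are hypothetical; the CogView4 class added later in this commit is the real example):

    class MyTransformerModel(BaseModel):
        def get_model_has_grad(self):
            # probe one representative weight instead of walking the whole module tree
            return self.model.proj_out.weight.requires_grad

        def get_te_has_grad(self):
            return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad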
@@ -512,6 +538,7 @@ class BaseModel:
                 self.device_torch, dtype=self.unet.dtype)

             img = self.generate_single_image(
+                pipeline,
                 gen_config,
                 conditional_embeds,
                 unconditional_embeds,
@@ -603,7 +630,8 @@ class BaseModel:
         self,
         original_samples: torch.FloatTensor,
         noise: torch.FloatTensor,
-        timesteps: torch.IntTensor
+        timesteps: torch.IntTensor,
+        **kwargs,
     ) -> torch.FloatTensor:
         original_samples_chunks = torch.chunk(
             original_samples, original_samples.shape[0], dim=0)
@@ -1071,7 +1099,7 @@ class BaseModel:
             for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"):
                 named_params[name] = param
         if unet:
-            if self.is_flux or self.is_lumina2:
+            if self.is_flux or self.is_lumina2 or self.is_transformer:
                 for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"):
                     named_params[name] = param
             else:
@@ -1105,59 +1133,11 @@ class BaseModel:
         return named_params

     def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
-        version_string = '1'
-        if self.is_v2:
-            version_string = '2'
-        if self.is_xl:
-            version_string = 'sdxl'
-        if self.is_ssd:
-            # overwrite sdxl because both wil be true here
-            version_string = 'ssd'
-        if self.is_ssd and self.is_vega:
-            version_string = 'vega'
-        # if output file does not end in .safetensors, then it is a directory and we are
-        # saving in diffusers format
-        if not output_file.endswith('.safetensors'):
-            # diffusers
-            if self.is_flux:
-                # only save the unet
-                transformer: FluxTransformer2DModel = unwrap_model(self.unet)
-                transformer.save_pretrained(
-                    save_directory=os.path.join(output_file, 'transformer'),
-                    safe_serialization=True,
-                )
-            elif self.is_lumina2:
-                # only save the unet
-                transformer: Lumina2Transformer2DModel = unwrap_model(
-                    self.unet)
-                transformer.save_pretrained(
-                    save_directory=os.path.join(output_file, 'transformer'),
-                    safe_serialization=True,
-                )
-
-            else:
-
-                self.pipeline.save_pretrained(
-                    save_directory=output_file,
-                    safe_serialization=True,
-                )
-            # save out meta config
-            meta_path = os.path.join(output_file, 'aitk_meta.yaml')
-            with open(meta_path, 'w') as f:
-                yaml.dump(meta, f)
-
-        else:
-            save_ldm_model_from_diffusers(
-                sd=self,
-                output_file=output_file,
-                meta=meta,
-                save_dtype=save_dtype,
-                sd_version=version_string,
-            )
-            if self.config_file is not None:
-                output_path_no_ext = os.path.splitext(output_file)[0]
-                output_config_path = f"{output_path_no_ext}.yaml"
-                shutil.copyfile(self.config_file, output_config_path)
+        self.save_model(
+            output_path=output_file,
+            meta=meta,
+            save_dtype=save_dtype
+        )

     def prepare_optimizer_params(
             self,
@@ -1240,12 +1220,7 @@ class BaseModel:
     def save_device_state(self):
         # saves the current device state for all modules
         # this is useful for when we want to alter the state and restore it
-        if self.is_lumina2:
-            unet_has_grad = self.unet.x_embedder.weight.requires_grad
-        elif self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
-            unet_has_grad = self.unet.proj_out.weight.requires_grad
-        else:
-            unet_has_grad = self.unet.conv_in.weight.requires_grad
+        unet_has_grad = self.get_model_has_grad()

         self.device_state = {
             **empty_preset,
@@ -1262,13 +1237,7 @@ class BaseModel:
         if isinstance(self.text_encoder, list):
             self.device_state['text_encoder']: List[dict] = []
             for encoder in self.text_encoder:
-                if isinstance(encoder, LlamaModel):
-                    te_has_grad = encoder.layers[0].mlp.gate_proj.weight.requires_grad
-                else:
-                    try:
-                        te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
-                    except:
-                        te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
+                te_has_grad = self.get_te_has_grad()
                 self.device_state['text_encoder'].append({
                     'training': encoder.training,
                     'device': encoder.device,
@@ -1276,17 +1245,7 @@ class BaseModel:
                     'requires_grad': te_has_grad
                 })
         else:
-            if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel):
-                te_has_grad = self.text_encoder.encoder.block[
-                    0].layer[0].SelfAttention.q.weight.requires_grad
-            elif isinstance(self.text_encoder, Gemma2Model):
-                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
-            elif isinstance(self.text_encoder, Qwen2Model):
-                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
-            elif isinstance(self.text_encoder, LlamaModel):
-                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
-            else:
-                te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad
+            te_has_grad = self.get_te_has_grad()

             self.device_state['text_encoder'] = {
                 'training': self.text_encoder.training,
toolkit/models/cogview4.py (new file, 458 lines)
@@ -0,0 +1,458 @@
import weakref
from diffusers import CogView4Pipeline
import torch
import yaml

from toolkit.basic import flush
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds

import os
import copy
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
import torch
import diffusers
from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
from transformers import GlmModel, AutoTokenizer
from diffusers import FlowMatchEulerDiscreteScheduler
from typing import TYPE_CHECKING
from toolkit.accelerator import unwrap_model

from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork


# remove this after a bug is fixed in diffusers code. This is a workaround.
class FakeModel:
    def __init__(self, model):
        self.model_ref = weakref.ref(model)
        pass

    @property
    def device(self):
        return self.model_ref().device


scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": 0.25,
    "invert_sigmas": False,
    "max_image_seq_len": 4096,
    "max_shift": 0.75,
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "time_shift_type": "linear",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False
}
class CogView4(BaseModel):
    def __init__(
            self,
            device,
            model_config: ModelConfig,
            dtype='bf16',
            custom_pipeline=None,
            noise_scheduler=None,
            **kwargs
    ):
        super().__init__(device, model_config, dtype,
                         custom_pipeline, noise_scheduler, **kwargs)
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['CogView4Transformer2DModel']

        # cache for holding noise
        self.effective_noise = None

    # static method to get the scheduler
    @staticmethod
    def get_train_scheduler():
        scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
        return scheduler

    def load_model(self):
        dtype = self.torch_dtype
        base_model_path = "THUDM/CogView4-6B"
        model_path = self.model_config.name_or_path

        # pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
        self.print_and_status_update("Loading CogView4 model")
        # base_model_path = "black-forest-labs/FLUX.1-schnell"
        base_model_path = self.model_config.name_or_path_original
        subfolder = 'transformer'
        transformer_path = model_path
        if os.path.exists(transformer_path):
            subfolder = None
            transformer_path = os.path.join(transformer_path, 'transformer')
            # check if the path is a full checkpoint.
            te_folder_path = os.path.join(model_path, 'text_encoder')
            # if we have the te, this folder is a full checkpoint, use it as the base
            if os.path.exists(te_folder_path):
                base_model_path = model_path

        self.print_and_status_update("Loading GlmModel")
        tokenizer = AutoTokenizer.from_pretrained(
            base_model_path, subfolder="tokenizer", torch_dtype=dtype)
        text_encoder = GlmModel.from_pretrained(
            base_model_path, subfolder="text_encoder", torch_dtype=dtype)

        text_encoder.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing GlmModel")
            quantize(text_encoder, weights=qfloat8)
            freeze(text_encoder)
            flush()

        # hack to fix diffusers bug workaround
        text_encoder.model = FakeModel(text_encoder)

        self.print_and_status_update("Loading transformer")
        transformer = CogView4Transformer2DModel.from_pretrained(
            transformer_path,
            subfolder=subfolder,
            torch_dtype=dtype,
        )

        if self.model_config.split_model_over_gpus:
            raise ValueError(
                "Splitting model over gpus is not supported for CogViewModels models")

        transformer.to(self.quantize_device, dtype=dtype)
        flush()

        if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
            raise ValueError(
                "Assistant LoRA is not supported for CogViewModels models currently")

        if self.model_config.lora_path is not None:
            raise ValueError(
                "Loading LoRA is not supported for CogViewModels models currently")

        flush()

        if self.model_config.quantize:
            # patch the state dict method
            patch_dequantization_on_save(transformer)
            quantization_type = qfloat8
            self.print_and_status_update("Quantizing transformer")
            quantize(transformer, weights=quantization_type,
                     **self.model_config.quantize_kwargs)
            freeze(transformer)
            transformer.to(self.device_torch)
        else:
            transformer.to(self.device_torch, dtype=dtype)

        flush()

        scheduler = CogView4.get_train_scheduler()
        self.print_and_status_update("Loading VAE")
        vae = AutoencoderKL.from_pretrained(
            base_model_path, subfolder="vae", torch_dtype=dtype)
        flush()

        self.print_and_status_update("Making pipe")
        pipe: CogView4Pipeline = CogView4Pipeline(
            scheduler=scheduler,
            text_encoder=None,
            tokenizer=tokenizer,
            vae=vae,
            transformer=None,
        )
        pipe.text_encoder = text_encoder
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")

        text_encoder = pipe.text_encoder
        tokenizer = pipe.tokenizer

        pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        text_encoder.to(self.device_torch)
        text_encoder.requires_grad_(False)
        text_encoder.eval()
        pipe.transformer = pipe.transformer.to(self.device_torch)
        flush()
        self.pipeline = pipe
        self.model = transformer
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
    def get_generation_pipeline(self):
        scheduler = CogView4.get_train_scheduler()
        pipeline = CogView4Pipeline(
            vae=self.vae,
            transformer=self.unet,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            scheduler=scheduler,
        )
        return pipeline

    def generate_single_image(
        self,
        pipeline: CogView4Pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        # there is a bug in the check in diffusers code that requires the prompt embeds to be the same length for conditional and unconditional
        # they are processed in 2 passes and the encoding code doesnt do this. So it shouldnt be needed. But, we will zero pad the shorter one. for now. Just inference here, so it should be fine.
        if conditional_embeds.text_embeds.shape[1] < unconditional_embeds.text_embeds.shape[1]:
            pad_len = unconditional_embeds.text_embeds.shape[1] - \
                conditional_embeds.text_embeds.shape[1]
            conditional_embeds.text_embeds = torch.cat([conditional_embeds.text_embeds, torch.zeros(conditional_embeds.text_embeds.shape[0], pad_len,
                                                                                                    conditional_embeds.text_embeds.shape[2], device=conditional_embeds.text_embeds.device, dtype=conditional_embeds.text_embeds.dtype)], dim=1)
        elif conditional_embeds.text_embeds.shape[1] > unconditional_embeds.text_embeds.shape[1]:
            pad_len = conditional_embeds.text_embeds.shape[1] - \
                unconditional_embeds.text_embeds.shape[1]
            unconditional_embeds.text_embeds = torch.cat([unconditional_embeds.text_embeds, torch.zeros(unconditional_embeds.text_embeds.shape[0], pad_len,
                                                                                                        unconditional_embeds.text_embeds.shape[2], device=unconditional_embeds.text_embeds.device, dtype=unconditional_embeds.text_embeds.dtype)], dim=1)

        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype),
            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype),
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            **extra
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
    ):
        # target_size = (height, width)
        target_size = latent_model_input.shape[-2:]
        # multiply by 8
        target_size = (target_size[0] * 8, target_size[1] * 8)
        crops_coords_top_left = torch.tensor(
            [(0, 0)], dtype=self.torch_dtype, device=self.device_torch)

        original_size = torch.tensor(
            [target_size], dtype=self.torch_dtype, device=self.device_torch)
        target_size = original_size.clone()
        noise_pred_cond = self.model(
            hidden_states=latent_model_input,  # torch.Size([1, 16, 128, 128])
            encoder_hidden_states=text_embeddings.text_embeds,  # torch.Size([1, 16, 4096])
            timestep=timestep,
            original_size=original_size,  # [[1024., 1024.]]
            target_size=target_size,  # [[1024., 1024.]]
            crop_coords=crops_coords_top_left,  # [[0., 0.]]
            return_dict=False,
        )[0]
        return noise_pred_cond

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        prompt_embeds, _ = self.pipeline.encode_prompt(
            prompt,
            do_classifier_free_guidance=False,
            device=self.device_torch,
            dtype=self.torch_dtype,
        )
        return PromptEmbeds(prompt_embeds)

    def get_model_has_grad(self):
        return self.model.proj_out.weight.requires_grad

    def get_te_has_grad(self):
        return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad

    def save_model(self, output_path, meta, save_dtype):
        # only save the unet
        transformer: CogView4Transformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, 'transformer'),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get('noise')
        effective_noise = self.effective_noise
        batch = kwargs.get('batch')
        if batch is None:
            raise ValueError("Batch is not provided")
        if noise is None:
            raise ValueError("Noise is not provided")
        # return batch.latents
        return (noise - batch.latents).detach()
        # return (effective_noise - batch.latents).detach()
    def _get_low_res_latents(self, latents):
        # todo prevent needing to do this and grab the tensor another way.
        with torch.no_grad():
            # Decode latents to image space
            images = self.decode_latents(latents, device=latents.device, dtype=latents.dtype)

            # Downsample by a factor of 2 using bilinear interpolation
            B, C, H, W = images.shape
            low_res_images = torch.nn.functional.interpolate(
                images,
                size=(H // 4, W // 4),
                mode="bilinear",
                align_corners=False
            )

            # Upsample back to original resolution to match expected VAE input dimensions
            upsampled_low_res_images = torch.nn.functional.interpolate(
                low_res_images,
                size=(H, W),
                mode="bilinear",
                align_corners=False
            )

            # Encode the low-resolution images back to latent space
            low_res_latents = self.encode_images(upsampled_low_res_images, device=latents.device, dtype=latents.dtype)
            return low_res_latents

    # def add_noise(
    #     self,
    #     original_samples: torch.FloatTensor,
    #     noise: torch.FloatTensor,
    #     timesteps: torch.IntTensor,
    #     **kwargs,
    # ) -> torch.FloatTensor:
    #     relay_start_point = 500

    #     # Store original samples for loss calculation
    #     self.original_samples = original_samples

    #     # Prepare chunks for batch processing
    #     original_samples_chunks = torch.chunk(
    #         original_samples, original_samples.shape[0], dim=0)
    #     noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
    #     timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)

    #     # Get the low res latents only if needed
    #     low_res_latents_chunks = None

    #     # Handle case where timesteps is a single value for all samples
    #     if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
    #         timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)

    #     noisy_latents_chunks = []
    #     effective_noise_chunks = []  # Store the effective noise for each sample

    #     for idx in range(original_samples.shape[0]):
    #         t = timesteps_chunks[idx]
    #         t_01 = (t / 1000).to(original_samples_chunks[idx].device)

    #         # Flowmatching interpolation between original and noise
    #         if t > relay_start_point:
    #             # Standard flowmatching - direct linear interpolation
    #             noisy_latents = (1 - t_01) * original_samples_chunks[idx] + t_01 * noise_chunks[idx]
    #             effective_noise_chunks.append(noise_chunks[idx])  # Effective noise is just the noise
    #         else:
    #             # Relay flowmatching case - only compute low_res_latents if needed
    #             if low_res_latents_chunks is None:
    #                 low_res_latents = self._get_low_res_latents(original_samples)
    #                 low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)

    #             # Calculate the relay ratio (0 to 1)
    #             t_ratio = t.float() / relay_start_point
    #             t_ratio = torch.clamp(t_ratio, 0.0, 1.0)

    #             # First blend between original and low-res based on t_ratio
    #             z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx]

    #             added_lor_res_noise = z0_t - original_samples_chunks[idx]

    #             # Then apply flowmatching interpolation between this blended state and noise
    #             noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx]

    #             # For prediction target, we need to store the effective "source"
    #             effective_noise_chunks.append(noise_chunks[idx] + added_lor_res_noise)

    #         noisy_latents_chunks.append(noisy_latents)

    #     noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
    #     self.effective_noise = torch.cat(effective_noise_chunks, dim=0)  # Store for loss calculation

    #     return noisy_latents

    # def add_noise(
    #     self,
    #     original_samples: torch.FloatTensor,
    #     noise: torch.FloatTensor,
    #     timesteps: torch.IntTensor,
    #     **kwargs,
    # ) -> torch.FloatTensor:
    #     relay_start_point = 500

    #     # Store original samples for loss calculation
    #     self.original_samples = original_samples

    #     # Prepare chunks for batch processing
    #     original_samples_chunks = torch.chunk(
    #         original_samples, original_samples.shape[0], dim=0)
    #     noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
    #     timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)

    #     # Get the low res latents only if needed
    #     low_res_latents = self._get_low_res_latents(original_samples)
    #     low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)

    #     # Handle case where timesteps is a single value for all samples
    #     if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
    #         timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)

    #     noisy_latents_chunks = []
    #     effective_noise_chunks = []  # Store the effective noise for each sample

    #     for idx in range(original_samples.shape[0]):
    #         t = timesteps_chunks[idx]
    #         t_01 = (t / 1000).to(original_samples_chunks[idx].device)

    #         lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx]
    #         lrln = lrln * (1 - t_01)

    #         # make the noise an interpolation between noise and low_res_latents with
    #         # being noise at t_01=1 and low_res_latents at t_01=0
    #         # new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln
    #         new_noise = noise_chunks[idx] + lrln

    #         # Then apply flowmatching interpolation between this blended state and noise
    #         noisy_latents = (1 - t_01) * original_samples + t_01 * new_noise

    #         # For prediction target, we need to store the effective "source"
    #         effective_noise_chunks.append(new_noise)

    #         noisy_latents_chunks.append(noisy_latents)

    #     noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
    #     self.effective_noise = torch.cat(effective_noise_chunks, dim=0)  # Store for loss calculation

    #     return noisy_latents
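To summarize the flow-matching setup this class plugs into (a sketch under the assumption of a plain MSE objective; transformer_fn and the tensor shapes are illustrative, not the toolkit's actual training loop):

    import torch
    import torch.nn.functional as F

    def flow_matching_step(transformer_fn, latents, timesteps):
        noise = torch.randn_like(latents)
        t01 = (timesteps / 1000).view(-1, 1, 1, 1).to(latents)
        noisy = (1 - t01) * latents + t01 * noise      # forward ODE, as in add_noise
        target = (noise - latents).detach()            # matches get_loss_target above
        pred = transformer_fn(noisy, timesteps)        # stand-in for the CogView4 transformer call
        return F.mse_loss(pred.float(), target.float())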
@@ -36,12 +36,11 @@ class Wan21(BaseModel):
         super().__init__(device, model_config, dtype,
                          custom_pipeline, noise_scheduler, **kwargs)
         self.is_flow_matching = True
         raise NotImplementedError("Wan21 is not implemented yet")
     # these must be implemented in child classes

     def load_model(self):
         self.pipeline = Wan21(

         )
         pass

     def get_generation_pipeline(self):
         # override this in child classes
@@ -50,6 +49,7 @@ class Wan21(BaseModel):

     def generate_single_image(
         self,
+        pipeline,
         gen_config: GenerateImageConfig,
         conditional_embeds: PromptEmbeds,
         unconditional_embeds: PromptEmbeds,
@@ -72,3 +72,11 @@ class Wan21(BaseModel):
     def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
         raise NotImplementedError(
             "get_prompt_embeds must be implemented in child classes")
+
+    def get_model_has_grad(self):
+        raise NotImplementedError(
+            "get_model_has_grad must be implemented in child classes")
+
+    def get_te_has_grad(self):
+        raise NotImplementedError(
+            "get_te_has_grad must be implemented in child classes")
@@ -44,7 +44,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
         hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())

         # flatten second half to max
-        hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max()
+        hbsmntw_weighing[num_timesteps //
+                         2:] = hbsmntw_weighing[num_timesteps // 2:].max()

         # Create linear timesteps from 1000 to 0
         timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
@@ -56,7 +57,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):

     def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
         # Get the indices of the timesteps
-        step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
+        step_indices = [(self.timesteps == t).nonzero().item()
+                        for t in timesteps]

         # Get the weights for the timesteps
         if v2:
@@ -70,7 +72,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
         sigmas = self.sigmas.to(device=device, dtype=dtype)
         schedule_timesteps = self.timesteps.to(device)
         timesteps = timesteps.to(device)
-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+        step_indices = [(schedule_timesteps == t).nonzero().item()
+                        for t in timesteps]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < n_dim:
@@ -84,27 +87,24 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
         noise: torch.Tensor,
         timesteps: torch.Tensor,
     ) -> torch.Tensor:
-        ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
-        ## Add noise according to flow matching.
-        ## zt = (1 - texp) * x + texp * z1
-
-        # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-        # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
-
-        # timestep needs to be in [0, 1], we store them in [0, 1000]
-        # noisy_sample = (1 - timestep) * latent + timestep * noise
         t_01 = (timesteps / 1000).to(original_samples.device)
+        # forward ODE
         noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
-
-        # n_dim = original_samples.ndim
-        # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
-        # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
+        # reverse ODE
+        # noisy_model_input = (1 - t_01) * noise + t_01 * original_samples
         return noisy_model_input

     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         return sample

-    def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None):
+    def set_train_timesteps(
+        self,
+        num_timesteps,
+        device,
+        timestep_type='linear',
+        latents=None,
+        patch_size=1
+    ):
         self.timestep_type = timestep_type
         if timestep_type == 'linear':
             timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
@@ -124,42 +124,67 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
             self.timesteps = timesteps.to(device=device)

             return timesteps
-        elif timestep_type == 'flux_shift' or timestep_type == 'lumina2_shift':
+        elif timestep_type in ['flux_shift', 'lumina2_shift', 'shift']:
             # matches inference dynamic shifting
             timesteps = np.linspace(
-                self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps
+                self._sigma_to_t(self.sigma_max), self._sigma_to_t(
+                    self.sigma_min), num_timesteps
             )

             sigmas = timesteps / self.config.num_train_timesteps

-            if latents is None:
-                raise ValueError('latents is None')
-
-            h = latents.shape[2] // 2  # Divide by ph
-            w = latents.shape[3] // 2  # Divide by pw
-            image_seq_len = h * w
-
-            # todo need to know the mu for the shift
-            mu = calculate_shift(
-                image_seq_len,
-                self.config.get("base_image_seq_len", 256),
-                self.config.get("max_image_seq_len", 4096),
-                self.config.get("base_shift", 0.5),
-                self.config.get("max_shift", 1.16),
-            )
-            sigmas = self.time_shift(mu, 1.0, sigmas)
-            sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+            if self.config.use_dynamic_shifting:
+                if latents is None:
+                    raise ValueError('latents is None')
+
+                # for flux we double up the patch size before sending her to simulate the latent reduction
+                h = latents.shape[2]
+                w = latents.shape[3]
+                image_seq_len = h * w // (patch_size**2)
+
+                mu = calculate_shift(
+                    image_seq_len,
+                    self.config.get("base_image_seq_len", 256),
+                    self.config.get("max_image_seq_len", 4096),
+                    self.config.get("base_shift", 0.5),
+                    self.config.get("max_shift", 1.16),
+                )
+                sigmas = self.time_shift(mu, 1.0, sigmas)
+            else:
+                sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
+
+            if self.config.shift_terminal:
+                sigmas = self.stretch_shift_to_terminal(sigmas)
+
+            if self.config.use_karras_sigmas:
+                sigmas = self._convert_to_karras(
+                    in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
+            elif self.config.use_exponential_sigmas:
+                sigmas = self._convert_to_exponential(
+                    in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
+            elif self.config.use_beta_sigmas:
+                sigmas = self._convert_to_beta(
+                    in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
+
+            sigmas = torch.from_numpy(sigmas).to(
+                dtype=torch.float32, device=device)
             timesteps = sigmas * self.config.num_train_timesteps

-            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+            if self.config.invert_sigmas:
+                sigmas = 1.0 - sigmas
+                timesteps = sigmas * self.config.num_train_timesteps
+                sigmas = torch.cat(
+                    [sigmas, torch.ones(1, device=sigmas.device)])
+            else:
+                sigmas = torch.cat(
+                    [sigmas, torch.zeros(1, device=sigmas.device)])

             self.timesteps = timesteps.to(device=device)
+            self.sigmas = sigmas

-            self.timesteps = timesteps.to(device=device)
             return timesteps

         elif timestep_type == 'lognorm_blend':
             # disgtribute timestepd to the center/early and blend in linear
             alpha = 0.75
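For reference, the mu used by the dynamic-shift branch is a linear interpolation between base_shift and max_shift over the image sequence length; a hedged re-derivation (not a copy of the diffusers calculate_shift helper):

    def estimate_mu(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.16):
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        return image_seq_len * m + b  # 256 tokens -> base_shift, 4096 tokens -> max_shift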
@@ -173,7 +198,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
             t1 = ((1 - t1/t1.max()) * 1000)

             # add half of linear
-            t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device)
+            t2 = torch.linspace(1000, 0, int(
+                num_timesteps * (1 - alpha)), device=device)
             timesteps = torch.cat((t1, t2))

             # Sort the timesteps in descending order
@@ -160,7 +160,6 @@ class StableDiffusion:
         self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline']
         self.vae: Union[None, 'AutoencoderKL']
         self.unet: Union[None, 'UNet2DConditionModel']
-        self.unet_unwrapped: Union[None, 'UNet2DConditionModel']
         self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
         self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
         self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
@@ -205,6 +204,8 @@ class StableDiffusion:
         self.invert_assistant_lora = False
         self._after_sample_img_hooks = []
         self._status_update_hooks = []
+        # todo update this based on the model
+        self.is_transformer = False

     # properties for old arch for backwards compatibility
     @property
@@ -246,6 +247,10 @@ class StableDiffusion:
     @property
     def is_lumina2(self):
         return self.arch == 'lumina2'
+
+    @property
+    def unet_unwrapped(self):
+        return unwrap_model(self.unet)

     def load_model(self):
         if self.is_loaded:
@@ -977,7 +982,6 @@ class StableDiffusion:
         if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2:
             # pixart and sd3 dont use a unet
             self.unet = pipe.transformer
-            self.unet_unwrapped = pipe.transformer
         else:
             self.unet: 'UNet2DConditionModel' = pipe.unet
         self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
@@ -1776,7 +1780,8 @@ class StableDiffusion:
         self,
         original_samples: torch.FloatTensor,
         noise: torch.FloatTensor,
-        timesteps: torch.IntTensor
+        timesteps: torch.IntTensor,
+        **kwargs,
     ) -> torch.FloatTensor:
         original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
         noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
@@ -5,5 +5,8 @@ def get_model_class(config: ModelConfig):
     if config.arch == "wan21":
         from toolkit.models.wan21 import Wan21
         return Wan21
+    elif config.arch == "cogview4":
+        from toolkit.models.cogview4 import CogView4
+        return CogView4
     else:
         return StableDiffusion
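A hedged usage sketch of the new branch; the module holding get_model_class and the exact ModelConfig keywords are assumptions here, but arch="cogview4" matches the ModelConfig change above:

    # config = ModelConfig(arch="cogview4", name_or_path="THUDM/CogView4-6B")  # kwargs assumed
    # ModelClass = get_model_class(config)   # resolves to toolkit.models.cogview4.CogView4
    # sd = ModelClass(device="cuda", model_config=config)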