Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-29 18:51:37 +00:00
Added cogview4. Loss still needs work.
@@ -168,11 +168,17 @@ class BaseModel:
        self.invert_assistant_lora = False
        self._after_sample_img_hooks = []
        self._status_update_hooks = []
        self.is_transformer = False

    # properties for old arch for backwards compatibility
    @property
    def unet(self):
        return self.model

    # set unet to model
    @unet.setter
    def unet(self, value):
        self.model = value

    @property
    def unet_unwrapped(self):
@@ -235,6 +241,7 @@ class BaseModel:

    def generate_single_image(
        self,
        pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
@@ -257,6 +264,25 @@ class BaseModel:
    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        raise NotImplementedError(
            "get_prompt_embeds must be implemented in child classes")

    def get_model_has_grad(self):
        raise NotImplementedError(
            "get_model_has_grad must be implemented in child classes")

    def get_te_has_grad(self):
        raise NotImplementedError(
            "get_te_has_grad must be implemented in child classes")

    def save_model(self, output_path, meta, save_dtype):
        # todo handle dtype without overloading anything (vram, cpu, etc)
        unwrap_model(self.pipeline).save_pretrained(
            save_directory=output_path,
            safe_serialization=True,
        )
        # save out meta config
        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
            yaml.dump(meta, f)
    # end must be implemented in child classes

    def te_train(self):
@@ -512,6 +538,7 @@ class BaseModel:
            self.device_torch, dtype=self.unet.dtype)

        img = self.generate_single_image(
            pipeline,
            gen_config,
            conditional_embeds,
            unconditional_embeds,
@@ -603,7 +630,8 @@ class BaseModel:
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor
        timesteps: torch.IntTensor,
        **kwargs,
    ) -> torch.FloatTensor:
        original_samples_chunks = torch.chunk(
            original_samples, original_samples.shape[0], dim=0)
@@ -1071,7 +1099,7 @@ class BaseModel:
                for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"):
                    named_params[name] = param
        if unet:
            if self.is_flux or self.is_lumina2:
            if self.is_flux or self.is_lumina2 or self.is_transformer:
                for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"):
                    named_params[name] = param
            else:
@@ -1105,59 +1133,11 @@ class BaseModel:
        return named_params

    def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
        version_string = '1'
        if self.is_v2:
            version_string = '2'
        if self.is_xl:
            version_string = 'sdxl'
        if self.is_ssd:
            # overwrite sdxl because both will be true here
            version_string = 'ssd'
        if self.is_ssd and self.is_vega:
            version_string = 'vega'
        # if output file does not end in .safetensors, then it is a directory and we are
        # saving in diffusers format
        if not output_file.endswith('.safetensors'):
            # diffusers
            if self.is_flux:
                # only save the unet
                transformer: FluxTransformer2DModel = unwrap_model(self.unet)
                transformer.save_pretrained(
                    save_directory=os.path.join(output_file, 'transformer'),
                    safe_serialization=True,
                )
            elif self.is_lumina2:
                # only save the unet
                transformer: Lumina2Transformer2DModel = unwrap_model(
                    self.unet)
                transformer.save_pretrained(
                    save_directory=os.path.join(output_file, 'transformer'),
                    safe_serialization=True,
                )

            else:

                self.pipeline.save_pretrained(
                    save_directory=output_file,
                    safe_serialization=True,
                )
            # save out meta config
            meta_path = os.path.join(output_file, 'aitk_meta.yaml')
            with open(meta_path, 'w') as f:
                yaml.dump(meta, f)

        else:
            save_ldm_model_from_diffusers(
                sd=self,
                output_file=output_file,
                meta=meta,
                save_dtype=save_dtype,
                sd_version=version_string,
            )
            if self.config_file is not None:
                output_path_no_ext = os.path.splitext(output_file)[0]
                output_config_path = f"{output_path_no_ext}.yaml"
                shutil.copyfile(self.config_file, output_config_path)
        self.save_model(
            output_path=output_file,
            meta=meta,
            save_dtype=save_dtype
        )

    def prepare_optimizer_params(
        self,
@@ -1240,12 +1220,7 @@ class BaseModel:
    def save_device_state(self):
        # saves the current device state for all modules
        # this is useful for when we want to alter the state and restore it
        if self.is_lumina2:
            unet_has_grad = self.unet.x_embedder.weight.requires_grad
        elif self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
            unet_has_grad = self.unet.proj_out.weight.requires_grad
        else:
            unet_has_grad = self.unet.conv_in.weight.requires_grad
        unet_has_grad = self.get_model_has_grad()

        self.device_state = {
            **empty_preset,
@@ -1262,13 +1237,7 @@ class BaseModel:
        if isinstance(self.text_encoder, list):
            self.device_state['text_encoder']: List[dict] = []
            for encoder in self.text_encoder:
                if isinstance(encoder, LlamaModel):
                    te_has_grad = encoder.layers[0].mlp.gate_proj.weight.requires_grad
                else:
                    try:
                        te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
                    except:
                        te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
                te_has_grad = self.get_te_has_grad()
                self.device_state['text_encoder'].append({
                    'training': encoder.training,
                    'device': encoder.device,
@@ -1276,17 +1245,7 @@ class BaseModel:
                    'requires_grad': te_has_grad
                })
        else:
            if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel):
                te_has_grad = self.text_encoder.encoder.block[
                    0].layer[0].SelfAttention.q.weight.requires_grad
            elif isinstance(self.text_encoder, Gemma2Model):
                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
            elif isinstance(self.text_encoder, Qwen2Model):
                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
            elif isinstance(self.text_encoder, LlamaModel):
                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
            else:
                te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad
            te_has_grad = self.get_te_has_grad()

            self.device_state['text_encoder'] = {
                'training': self.text_encoder.training,
toolkit/models/cogview4.py (new file, 458 lines)
@@ -0,0 +1,458 @@
import weakref
from diffusers import CogView4Pipeline
import torch
import yaml

from toolkit.basic import flush
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds

import os
import copy
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
import torch
import diffusers
from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
from transformers import GlmModel, AutoTokenizer
from diffusers import FlowMatchEulerDiscreteScheduler
from typing import TYPE_CHECKING
from toolkit.accelerator import unwrap_model

from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork

# remove this after a bug is fixed in diffusers code. This is a workaround.


class FakeModel:
    def __init__(self, model):
        self.model_ref = weakref.ref(model)
        pass

    @property
    def device(self):
        return self.model_ref().device


scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": 0.25,
    "invert_sigmas": False,
    "max_image_seq_len": 4096,
    "max_shift": 0.75,
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "time_shift_type": "linear",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False
}
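
# Illustration only, not part of this commit: with use_dynamic_shifting enabled, the
# pipeline derives the scheduler's shift ("mu") from the latent sequence length by
# interpolating between base_shift and max_shift. A minimal sketch of that mapping,
# assuming the Flux-style calculate_shift helper used elsewhere in diffusers; the exact
# CogView4 behavior is defined by the pipeline and CustomFlowMatchEulerDiscreteScheduler.
def _example_calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,    # base_image_seq_len above
    max_seq_len: int = 4096,    # max_image_seq_len above
    base_shift: float = 0.25,
    max_shift: float = 0.75,
) -> float:
    # linear interpolation: 256 tokens -> 0.25, 4096 tokens -> 0.75
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b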


class CogView4(BaseModel):
    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='bf16',
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs
    ):
        super().__init__(device, model_config, dtype,
                         custom_pipeline, noise_scheduler, **kwargs)
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['CogView4Transformer2DModel']

        # cache for holding noise
        self.effective_noise = None

    # static method to get the scheduler
    @staticmethod
    def get_train_scheduler():
        scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
        return scheduler

    def load_model(self):
        dtype = self.torch_dtype
        base_model_path = "THUDM/CogView4-6B"
        model_path = self.model_config.name_or_path

        # pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
        self.print_and_status_update("Loading CogView4 model")
        # base_model_path = "black-forest-labs/FLUX.1-schnell"
        base_model_path = self.model_config.name_or_path_original
        subfolder = 'transformer'
        transformer_path = model_path
        if os.path.exists(transformer_path):
            subfolder = None
            transformer_path = os.path.join(transformer_path, 'transformer')
            # check if the path is a full checkpoint.
            te_folder_path = os.path.join(model_path, 'text_encoder')
            # if we have the te, this folder is a full checkpoint, use it as the base
            if os.path.exists(te_folder_path):
                base_model_path = model_path

self.print_and_status_update("Loading GlmModel")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
base_model_path, subfolder="tokenizer", torch_dtype=dtype)
|
||||
text_encoder = GlmModel.from_pretrained(
|
||||
base_model_path, subfolder="text_encoder", torch_dtype=dtype)
|
||||
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize_te:
|
||||
self.print_and_status_update("Quantizing GlmModel")
|
||||
quantize(text_encoder, weights=qfloat8)
|
||||
freeze(text_encoder)
|
||||
flush()
|
||||
|
||||
# hack to fix diffusers bug workaround
|
||||
text_encoder.model = FakeModel(text_encoder)
|
||||
|
||||
self.print_and_status_update("Loading transformer")
|
||||
transformer = CogView4Transformer2DModel.from_pretrained(
|
||||
transformer_path,
|
||||
subfolder=subfolder,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
|
||||
if self.model_config.split_model_over_gpus:
|
||||
raise ValueError(
|
||||
"Splitting model over gpus is not supported for CogViewModels models")
|
||||
|
||||
transformer.to(self.quantize_device, dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
|
||||
raise ValueError(
|
||||
"Assistant LoRA is not supported for CogViewModels models currently")
|
||||
|
||||
if self.model_config.lora_path is not None:
|
||||
raise ValueError(
|
||||
"Loading LoRA is not supported for CogViewModels models currently")
|
||||
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize:
|
||||
# patch the state dict method
|
||||
patch_dequantization_on_save(transformer)
|
||||
quantization_type = qfloat8
|
||||
self.print_and_status_update("Quantizing transformer")
|
||||
quantize(transformer, weights=quantization_type,
|
||||
**self.model_config.quantize_kwargs)
|
||||
freeze(transformer)
|
||||
transformer.to(self.device_torch)
|
||||
else:
|
||||
transformer.to(self.device_torch, dtype=dtype)
|
||||
|
||||
        flush()

        scheduler = CogView4.get_train_scheduler()
        self.print_and_status_update("Loading VAE")
        vae = AutoencoderKL.from_pretrained(
            base_model_path, subfolder="vae", torch_dtype=dtype)
        flush()

        self.print_and_status_update("Making pipe")
        pipe: CogView4Pipeline = CogView4Pipeline(
            scheduler=scheduler,
            text_encoder=None,
            tokenizer=tokenizer,
            vae=vae,
            transformer=None,
        )
        pipe.text_encoder = text_encoder
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")

        text_encoder = pipe.text_encoder
        tokenizer = pipe.tokenizer

        pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        text_encoder.to(self.device_torch)
        text_encoder.requires_grad_(False)
        text_encoder.eval()
        pipe.transformer = pipe.transformer.to(self.device_torch)
        flush()
        self.pipeline = pipe
        self.model = transformer
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer

    def get_generation_pipeline(self):
        scheduler = CogView4.get_train_scheduler()
        pipeline = CogView4Pipeline(
            vae=self.vae,
            transformer=self.unet,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            scheduler=scheduler,
        )
        return pipeline

    def generate_single_image(
        self,
        pipeline: CogView4Pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        # there is a bug in a check in the diffusers code that requires the conditional and
        # unconditional prompt embeds to be the same length. They are processed in 2 passes and
        # the encoding code doesn't pad them, so the check shouldn't be needed. For now we zero
        # pad the shorter one. This is only used for inference here, so it should be fine.
        if conditional_embeds.text_embeds.shape[1] < unconditional_embeds.text_embeds.shape[1]:
            pad_len = unconditional_embeds.text_embeds.shape[1] - \
                conditional_embeds.text_embeds.shape[1]
            conditional_embeds.text_embeds = torch.cat([conditional_embeds.text_embeds, torch.zeros(conditional_embeds.text_embeds.shape[0], pad_len,
                                                        conditional_embeds.text_embeds.shape[2], device=conditional_embeds.text_embeds.device, dtype=conditional_embeds.text_embeds.dtype)], dim=1)
        elif conditional_embeds.text_embeds.shape[1] > unconditional_embeds.text_embeds.shape[1]:
            pad_len = conditional_embeds.text_embeds.shape[1] - \
                unconditional_embeds.text_embeds.shape[1]
            unconditional_embeds.text_embeds = torch.cat([unconditional_embeds.text_embeds, torch.zeros(unconditional_embeds.text_embeds.shape[0], pad_len,
                                                          unconditional_embeds.text_embeds.shape[2], device=unconditional_embeds.text_embeds.device, dtype=unconditional_embeds.text_embeds.dtype)], dim=1)

        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype),
            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype),
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            **extra
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
    ):
        # target_size = (height, width)
        target_size = latent_model_input.shape[-2:]
        # multiply by 8
        target_size = (target_size[0] * 8, target_size[1] * 8)
        crops_coords_top_left = torch.tensor(
            [(0, 0)], dtype=self.torch_dtype, device=self.device_torch)

        original_size = torch.tensor(
            [target_size], dtype=self.torch_dtype, device=self.device_torch)
        target_size = original_size.clone()
        noise_pred_cond = self.model(
            hidden_states=latent_model_input,  # torch.Size([1, 16, 128, 128])
            encoder_hidden_states=text_embeddings.text_embeds,  # torch.Size([1, 16, 4096])
            timestep=timestep,
            original_size=original_size,  # [[1024., 1024.]]
            target_size=target_size,  # [[1024., 1024.]]
            crop_coords=crops_coords_top_left,  # [[0., 0.]]
            return_dict=False,
        )[0]
        return noise_pred_cond

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        prompt_embeds, _ = self.pipeline.encode_prompt(
            prompt,
            do_classifier_free_guidance=False,
            device=self.device_torch,
            dtype=self.torch_dtype,
        )
        return PromptEmbeds(prompt_embeds)

    def get_model_has_grad(self):
        return self.model.proj_out.weight.requires_grad

    def get_te_has_grad(self):
        return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad

    def save_model(self, output_path, meta, save_dtype):
        # only save the unet
        transformer: CogView4Transformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, 'transformer'),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get('noise')
        effective_noise = self.effective_noise
        batch = kwargs.get('batch')
        if batch is None:
            raise ValueError("Batch is not provided")
        if noise is None:
            raise ValueError("Noise is not provided")
        # return batch.latents
        return (noise - batch.latents).detach()
        # return (effective_noise - batch.latents).detach()
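
    # Illustration only (hypothetical names; the trainer loop lives outside this file): a
    # flow-matching trainer would typically regress the model prediction onto this target
    # with an MSE, i.e. the velocity pointing from the clean latents toward the noise:
    #
    #   target = model.get_loss_target(noise=noise, batch=batch)  # noise - x0
    #   loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float())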


    def _get_low_res_latents(self, latents):
        # todo prevent needing to do this and grab the tensor another way.
        with torch.no_grad():
            # Decode latents to image space
            images = self.decode_latents(latents, device=latents.device, dtype=latents.dtype)

            # Downsample by a factor of 4 using bilinear interpolation
            B, C, H, W = images.shape
            low_res_images = torch.nn.functional.interpolate(
                images,
                size=(H // 4, W // 4),
                mode="bilinear",
                align_corners=False
            )

            # Upsample back to original resolution to match expected VAE input dimensions
            upsampled_low_res_images = torch.nn.functional.interpolate(
                low_res_images,
                size=(H, W),
                mode="bilinear",
                align_corners=False
            )

            # Encode the low-resolution images back to latent space
            low_res_latents = self.encode_images(upsampled_low_res_images, device=latents.device, dtype=latents.dtype)
        return low_res_latents


    # def add_noise(
    #     self,
    #     original_samples: torch.FloatTensor,
    #     noise: torch.FloatTensor,
    #     timesteps: torch.IntTensor,
    #     **kwargs,
    # ) -> torch.FloatTensor:
    #     relay_start_point = 500

    #     # Store original samples for loss calculation
    #     self.original_samples = original_samples

    #     # Prepare chunks for batch processing
    #     original_samples_chunks = torch.chunk(
    #         original_samples, original_samples.shape[0], dim=0)
    #     noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
    #     timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)

    #     # Get the low res latents only if needed
    #     low_res_latents_chunks = None

    #     # Handle case where timesteps is a single value for all samples
    #     if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
    #         timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)

    #     noisy_latents_chunks = []
    #     effective_noise_chunks = []  # Store the effective noise for each sample

    #     for idx in range(original_samples.shape[0]):
    #         t = timesteps_chunks[idx]
    #         t_01 = (t / 1000).to(original_samples_chunks[idx].device)

    #         # Flowmatching interpolation between original and noise
    #         if t > relay_start_point:
    #             # Standard flowmatching - direct linear interpolation
    #             noisy_latents = (1 - t_01) * original_samples_chunks[idx] + t_01 * noise_chunks[idx]
    #             effective_noise_chunks.append(noise_chunks[idx])  # Effective noise is just the noise
    #         else:
    #             # Relay flowmatching case - only compute low_res_latents if needed
    #             if low_res_latents_chunks is None:
    #                 low_res_latents = self._get_low_res_latents(original_samples)
    #                 low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)

    #             # Calculate the relay ratio (0 to 1)
    #             t_ratio = t.float() / relay_start_point
    #             t_ratio = torch.clamp(t_ratio, 0.0, 1.0)

    #             # First blend between original and low-res based on t_ratio
    #             z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx]

    #             added_lor_res_noise = z0_t - original_samples_chunks[idx]

    #             # Then apply flowmatching interpolation between this blended state and noise
    #             noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx]

    #             # For prediction target, we need to store the effective "source"
    #             effective_noise_chunks.append(noise_chunks[idx] + added_lor_res_noise)

    #         noisy_latents_chunks.append(noisy_latents)

    #     noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
    #     self.effective_noise = torch.cat(effective_noise_chunks, dim=0)  # Store for loss calculation

    #     return noisy_latents


    # def add_noise(
    #     self,
    #     original_samples: torch.FloatTensor,
    #     noise: torch.FloatTensor,
    #     timesteps: torch.IntTensor,
    #     **kwargs,
    # ) -> torch.FloatTensor:
    #     relay_start_point = 500

    #     # Store original samples for loss calculation
    #     self.original_samples = original_samples

    #     # Prepare chunks for batch processing
    #     original_samples_chunks = torch.chunk(
    #         original_samples, original_samples.shape[0], dim=0)
    #     noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
    #     timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)

    #     # Get the low res latents only if needed
    #     low_res_latents = self._get_low_res_latents(original_samples)
    #     low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)

    #     # Handle case where timesteps is a single value for all samples
    #     if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
    #         timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)

    #     noisy_latents_chunks = []
    #     effective_noise_chunks = []  # Store the effective noise for each sample

    #     for idx in range(original_samples.shape[0]):
    #         t = timesteps_chunks[idx]
    #         t_01 = (t / 1000).to(original_samples_chunks[idx].device)

    #         lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx]
    #         lrln = lrln * (1 - t_01)

    #         # make the noise an interpolation between noise and low_res_latents with
    #         # being noise at t_01=1 and low_res_latents at t_01=0
    #         # new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln
    #         new_noise = noise_chunks[idx] + lrln

    #         # Then apply flowmatching interpolation between this blended state and noise
    #         noisy_latents = (1 - t_01) * original_samples + t_01 * new_noise

    #         # For prediction target, we need to store the effective "source"
    #         effective_noise_chunks.append(new_noise)

    #         noisy_latents_chunks.append(noisy_latents)

    #     noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
    #     self.effective_noise = torch.cat(effective_noise_chunks, dim=0)  # Store for loss calculation

    #     return noisy_latents
@@ -36,12 +36,11 @@ class Wan21(BaseModel):
        super().__init__(device, model_config, dtype,
                         custom_pipeline, noise_scheduler, **kwargs)
        self.is_flow_matching = True
        raise NotImplementedError("Wan21 is not implemented yet")
    # these must be implemented in child classes

    def load_model(self):
        self.pipeline = Wan21(

        )
        pass

    def get_generation_pipeline(self):
        # override this in child classes
@@ -50,6 +49,7 @@ class Wan21(BaseModel):

    def generate_single_image(
        self,
        pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
@@ -72,3 +72,11 @@ class Wan21(BaseModel):
    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        raise NotImplementedError(
            "get_prompt_embeds must be implemented in child classes")

    def get_model_has_grad(self):
        raise NotImplementedError(
            "get_model_has_grad must be implemented in child classes")

    def get_te_has_grad(self):
        raise NotImplementedError(
            "get_te_has_grad must be implemented in child classes")