# WIP, coming soon ish
import copy
import os
import sys
import weakref
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import torch
import yaml
from optimum.quanto import freeze, qfloat8, QTensor, qint4
from torchvision.transforms import Resize, ToPILImage
from tqdm import tqdm
from transformers import AutoTokenizer, UMT5EncoderModel

from diffusers import (
    AutoencoderKLWan,
    FlowMatchEulerDiscreteScheduler,
    UniPCMultistepScheduler,
    WanPipeline,
    WanTransformer3DModel,
)
# from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE

from toolkit.accelerator import unwrap_model
from toolkit.basic import flush
from toolkit.config_modules import GenerateImageConfig, ModelArch, ModelConfig
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.models.base_model import BaseModel
from toolkit.models.wan21.wan_lora_convert import convert_to_diffusers, convert_to_original
from toolkit.paths import REPOS_ROOT
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.util.quantize import quantize

# xm is only referenced in the XLA branch of the denoising loop; import it here
# so that branch does not fail with a NameError (assumes torch_xla is installed
# whenever XLA_AVAILABLE is True, mirroring the diffusers pipeline module).
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm
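
# The two dicts below configure the schedulers: UniPC for sampling and the
# custom flow-match Euler scheduler for training. Both use a flow shift of 3.0,
# which, as far as I can tell, is the value usually recommended for 480p Wan 2.1 output.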

# for generation only?
scheduler_configUniPC = {
    "_class_name": "UniPCMultistepScheduler",
    "_diffusers_version": "0.33.0.dev0",
    "beta_end": 0.02,
    "beta_schedule": "linear",
    "beta_start": 0.0001,
    "disable_corrector": [],
    "dynamic_thresholding_ratio": 0.995,
    "final_sigmas_type": "zero",
    "flow_shift": 3.0,
    "lower_order_final": True,
    "num_train_timesteps": 1000,
    "predict_x0": True,
    "prediction_type": "flow_prediction",
    "rescale_betas_zero_snr": False,
    "sample_max_value": 1.0,
    "solver_order": 2,
    "solver_p": None,
    "solver_type": "bh2",
    "steps_offset": 0,
    "thresholding": False,
    "timestep_spacing": "linspace",
    "trained_betas": None,
    "use_beta_sigmas": False,
    "use_exponential_sigmas": False,
    "use_flow_sigmas": True,
    "use_karras_sigmas": False
}

# for training. I think it is right
scheduler_config = {
    "num_train_timesteps": 1000,
    "shift": 3.0,
    "use_dynamic_shifting": False
}
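
# Sampling pipeline for low-VRAM setups. It behaves like WanPipeline, but only
# keeps the sub-model needed for the current stage (text encoder -> transformer
# -> VAE) on the GPU and moves the others to the CPU in between.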

class AggressiveWanUnloadPipeline(WanPipeline):
    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: UMT5EncoderModel,
        transformer: WanTransformer3DModel,
        vae: AutoencoderKLWan,
        scheduler: FlowMatchEulerDiscreteScheduler,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
        )
        self._exec_device = device

    @property
    def _execution_device(self):
        return self._exec_device
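
    # Mirrors WanPipeline.__call__, with extra .to() calls that shuffle the
    # sub-models between the CPU and the execution device around each stage.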
    def __call__(
        self: WanPipeline,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None],
                  PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # record where each sub-model lives, then move the VAE to the CPU and
        # bring the text encoder onto the execution device
        vae_device = self.vae.device
        transformer_device = self.transformer.device
        text_encoder_device = self.text_encoder.device
        device = self.transformer.device

        print("Unloading vae")
        self.vae.to("cpu")
        self.text_encoder.to(device)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            negative_prompt,
            height,
            width,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )

        # unload text encoder
        print("Unloading text encoder")
        self.text_encoder.to("cpu")

        self.transformer.to(device)

        transformer_dtype = self.transformer.dtype
        prompt_embeds = prompt_embeds.to(device, transformer_dtype)
        if negative_prompt_embeds is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(
                device, transformer_dtype)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float32,
            device,
            generator,
            latents,
        )

        # 6. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = latents.to(device, transformer_dtype)
                timestep = t.expand(latents.shape[0])

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]

                if self.do_classifier_free_guidance:
                    noise_uncond = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_uncond + guidance_scale * \
                        (noise_pred - noise_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(
                        self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop(
                        "prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop(
                        "negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        # denoising is done; bring the VAE back to its original device for decoding
        # (the transformer is left in place here)
        print("Loading VAE")
        self.vae.to(vae_device)

        if not output_type == "latent":
            latents = latents.to(self.vae.dtype)
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean)
                .view(1, self.vae.config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
                latents.device, latents.dtype
            )
            latents = latents / latents_std + latents_mean
            video = self.vae.decode(latents, return_dict=False)[0]
            video = self.video_processor.postprocess_video(
                video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return WanPipelineOutput(frames=video)
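
# BaseModel wrapper used by ai-toolkit to train and sample Wan 2.1
# text-to-video models with flow matching.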

class Wan21(BaseModel):
    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='bf16',
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs
    ):
        super().__init__(device, model_config, dtype,
                         custom_pipeline, noise_scheduler, **kwargs)
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['WanTransformer3DModel']

        # cache for holding noise
        self.effective_noise = None

    # static method to get the scheduler
    @staticmethod
    def get_train_scheduler():
        scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
        return scheduler
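
    # Loads the transformer, text encoder, tokenizer and VAE, optionally
    # quantizing the transformer / text encoder with optimum-quanto, then
    # assembles everything into a WanPipeline.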
    def load_model(self):
        dtype = self.torch_dtype
        # todo, will this work with other wan models?
        base_model_path = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
        model_path = self.model_config.name_or_path

        self.print_and_status_update("Loading Wan2.1 model")
        # base_model_path = "black-forest-labs/FLUX.1-schnell"
        base_model_path = self.model_config.name_or_path_original
        subfolder = 'transformer'
        transformer_path = model_path
        if os.path.exists(transformer_path):
            subfolder = None
            transformer_path = os.path.join(transformer_path, 'transformer')
            # check if the path is a full checkpoint.
            te_folder_path = os.path.join(model_path, 'text_encoder')
            # if we have the te, this folder is a full checkpoint, use it as the base
            if os.path.exists(te_folder_path):
                base_model_path = model_path
self.print_and_status_update("Loading transformer")
|
|
transformer = WanTransformer3DModel.from_pretrained(
|
|
transformer_path,
|
|
subfolder=subfolder,
|
|
torch_dtype=dtype,
|
|
).to(dtype=dtype)
|
|
|
|
if self.model_config.split_model_over_gpus:
|
|
raise ValueError(
|
|
"Splitting model over gpus is not supported for Wan2.1 models")
|
|
|
|
if not self.model_config.low_vram:
|
|
# quantize on the device
|
|
transformer.to(self.quantize_device, dtype=dtype)
|
|
flush()
|
|
|
|
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
|
|
raise ValueError(
|
|
"Assistant LoRA is not supported for Wan2.1 models currently")
|
|
|
|
if self.model_config.lora_path is not None:
|
|
raise ValueError(
|
|
"Loading LoRA is not supported for Wan2.1 models currently")
|
|
|
|
flush()
|
|
|
|
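
        # In low-VRAM mode the transformer blocks are quantized one at a time on
        # the GPU, then the remaining (non-block) modules are quantized with
        # 'blocks.*' excluded so the blocks are not quantized twice.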
        if self.model_config.quantize:
            print("Quantizing Transformer")
            quantization_args = self.model_config.quantize_kwargs
            if 'exclude' not in quantization_args:
                quantization_args['exclude'] = []
            # patch the state dict method
            patch_dequantization_on_save(transformer)
            quantization_type = qfloat8
            self.print_and_status_update("Quantizing transformer")
            if self.model_config.low_vram:
                print("Quantizing blocks")
                orig_exclude = copy.deepcopy(quantization_args['exclude'])
                # quantize each block
                idx = 0
                for block in tqdm(transformer.blocks):
                    block.to(self.device_torch)
                    quantize(block, weights=quantization_type,
                             **quantization_args)
                    freeze(block)
                    idx += 1
                    flush()

                print("Quantizing the rest")
                low_vram_exclude = copy.deepcopy(quantization_args['exclude'])
                low_vram_exclude.append('blocks.*')
                quantization_args['exclude'] = low_vram_exclude
                # quantize the rest
                transformer.to(self.device_torch)
                quantize(transformer, weights=quantization_type,
                         **quantization_args)

                quantization_args['exclude'] = orig_exclude
            else:
                # do it in one go
                quantize(transformer, weights=quantization_type,
                         **quantization_args)
            freeze(transformer)
            # move it to the cpu for now
            transformer.to("cpu")
        else:
            transformer.to(self.device_torch, dtype=dtype)

        flush()
self.print_and_status_update("Loading UMT5EncoderModel")
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
base_model_path, subfolder="tokenizer", torch_dtype=dtype)
|
|
text_encoder = UMT5EncoderModel.from_pretrained(
|
|
base_model_path, subfolder="text_encoder", torch_dtype=dtype).to(dtype=dtype)
|
|
|
|
text_encoder.to(self.device_torch, dtype=dtype)
|
|
flush()
|
|
|
|
if self.model_config.quantize_te:
|
|
self.print_and_status_update("Quantizing UMT5EncoderModel")
|
|
quantize(text_encoder, weights=qfloat8)
|
|
freeze(text_encoder)
|
|
flush()
|
|
|
|
if self.model_config.low_vram:
|
|
print("Moving transformer back to GPU")
|
|
# we can move it back to the gpu now
|
|
transformer.to(self.device_torch)
|
|
|
|
scheduler = Wan21.get_train_scheduler()
|
|
self.print_and_status_update("Loading VAE")
|
|
# todo, example does float 32? check if quality suffers
|
|
vae = AutoencoderKLWan.from_pretrained(
|
|
base_model_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype)
|
|
flush()
|
|
|
|
self.print_and_status_update("Making pipe")
|
|
pipe: WanPipeline = WanPipeline(
|
|
scheduler=scheduler,
|
|
text_encoder=None,
|
|
tokenizer=tokenizer,
|
|
vae=vae,
|
|
transformer=None,
|
|
)
|
|
pipe.text_encoder = text_encoder
|
|
pipe.transformer = transformer
|
|
|
|
self.print_and_status_update("Preparing Model")
|
|
|
|
text_encoder = pipe.text_encoder
|
|
tokenizer = pipe.tokenizer
|
|
|
|
pipe.transformer = pipe.transformer.to(self.device_torch)
|
|
|
|
flush()
|
|
text_encoder.to(self.device_torch)
|
|
text_encoder.requires_grad_(False)
|
|
text_encoder.eval()
|
|
pipe.transformer = pipe.transformer.to(self.device_torch)
|
|
flush()
|
|
self.pipeline = pipe
|
|
self.model = transformer
|
|
self.vae = vae
|
|
self.text_encoder = text_encoder
|
|
self.tokenizer = tokenizer
|
|
|
|
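
    # Builds the sampling pipeline. In low-VRAM mode the AggressiveWanUnloadPipeline
    # defined above is used in place of the stock WanPipeline.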
    def get_generation_pipeline(self):
        scheduler = UniPCMultistepScheduler(**scheduler_configUniPC)
        if self.model_config.low_vram:
            pipeline = AggressiveWanUnloadPipeline(
                vae=self.vae,
                transformer=self.model,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                scheduler=scheduler,
                device=self.device_torch
            )
        else:
            pipeline = WanPipeline(
                vae=self.vae,
                transformer=self.unet,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                scheduler=scheduler,
            )

        pipeline = pipeline.to(self.device_torch)

        return pipeline
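
    # Runs the sampling pipeline with precomputed prompt embeddings and returns
    # either a list of PIL frames (num_frames > 1) or a single PIL image.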
    def generate_single_image(
        self,
        pipeline: WanPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        # reactivate progress bar since this is slooooow
        pipeline.set_progress_bar_config(disable=False)
        pipeline = pipeline.to(self.device_torch)
        # todo, figure out how to do video
        output = pipeline(
            prompt_embeds=conditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype),
            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype),
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            num_frames=gen_config.num_frames,
            generator=generator,
            return_dict=False,
            output_type="pil",
            **extra
        )[0]

        # shape = [1, frames, channels, height, width]
        batch_item = output[0]  # list of pil images
        if gen_config.num_frames > 1:
            return batch_item  # return the frames.
        else:
            # get just the first image
            img = batch_item[0]
            return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
    ):
        # vae_scale_factor_spatial = 8
        # vae_scale_factor_temporal = 4
        # num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        # shape = (
        #     batch_size,
        #     num_channels_latents,  # 16
        #     num_latent_frames,  # 81
        #     int(height) // self.vae_scale_factor_spatial,
        #     int(width) // self.vae_scale_factor_spatial,
        # )

        noise_pred = self.model(
            hidden_states=latent_model_input,
            timestep=timestep,
            encoder_hidden_states=text_embeddings.text_embeds,
            return_dict=False,
            **kwargs
        )[0]
        return noise_pred

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        if self.pipeline.text_encoder.device != self.device_torch:
            self.pipeline.text_encoder.to(self.device_torch)
        prompt_embeds, _ = self.pipeline.encode_prompt(
            prompt,
            do_classifier_free_guidance=False,
            max_sequence_length=512,
            device=self.device_torch,
            dtype=self.torch_dtype,
        )
        return PromptEmbeds(prompt_embeds)
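
    # Encodes images or videos into VAE latents. Inputs arrive as (C, H, W)
    # images or (F, C, H, W) videos; both are mapped to (C, F, H, W) before
    # encoding, and the latents are normalized with the VAE's per-channel
    # mean/std from its config.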
    @torch.no_grad()
    def encode_images(
        self,
        image_list: List[torch.Tensor],
        device=None,
        dtype=None
    ):
        if device is None:
            device = self.vae_device_torch
        if dtype is None:
            dtype = self.vae_torch_dtype

        # move the vae to the device if it is on cpu
        if self.vae.device == 'cpu':
            self.vae.to(device)
        self.vae.eval()
        self.vae.requires_grad_(False)
        # move to device and dtype
        image_list = [image.to(device, dtype=dtype) for image in image_list]

        # We need to detect video if we have it.
        # videos come in (num_frames, channels, height, width)
        # images come in (channels, height, width)
        # we need to add a frame dimension to images and remap the video to (channels, num_frames, height, width)

        if len(image_list[0].shape) == 3:
            image_list = [image.unsqueeze(1) for image in image_list]
        elif len(image_list[0].shape) == 4:
            image_list = [image.permute(1, 0, 2, 3) for image in image_list]
        else:
            raise ValueError(
                f"Image shape is not correct, got {list(image_list[0].shape)}")

        VAE_SCALE_FACTOR = 8

        # resize images if not divisible by 8
        # now we need to resize considering the shape (channels, num_frames, height, width)
        for i in range(len(image_list)):
            image = image_list[i]
            if image.shape[2] % VAE_SCALE_FACTOR != 0 or image.shape[3] % VAE_SCALE_FACTOR != 0:
                # Create resized frames by handling each frame separately
                c, f, h, w = image.shape
                target_h = h // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR
                target_w = w // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR

                # We need to process each frame separately
                resized_frames = []
                for frame_idx in range(f):
                    # Extract single frame (channels, height, width)
                    frame = image[:, frame_idx, :, :]
                    resized_frame = Resize((target_h, target_w))(frame)
                    # Add frame dimension back
                    resized_frames.append(resized_frame.unsqueeze(1))

                # Concatenate all frames back together along the frame dimension
                image_list[i] = torch.cat(resized_frames, dim=1)

        images = torch.stack(image_list)
        # images = images.unsqueeze(2) # adds frame dimension so (bs, ch, h, w) -> (bs, ch, 1, h, w)
        latents = self.vae.encode(images).latent_dist.sample()

        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = (latents - latents_mean) * latents_std

        latents = latents.to(device, dtype=dtype)

        return latents

    def get_model_has_grad(self):
        return self.model.proj_out.weight.requires_grad

    def get_te_has_grad(self):
        return self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad

    def save_model(self, output_path, meta, save_dtype):
        # only save the transformer
        transformer: WanTransformer3DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, 'transformer'),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
            yaml.dump(meta, f)
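
    # Flow-matching target: the velocity pointing from the clean latents toward
    # the sampled noise, i.e. noise - latents.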
    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get('noise')
        batch = kwargs.get('batch')
        if batch is None:
            raise ValueError("Batch is not provided")
        if noise is None:
            raise ValueError("Noise is not provided")
        return (noise - batch.latents).detach()
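
    # LoRA weights are saved in the original Wan key layout and converted back
    # to the diffusers layout when loaded.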
    def convert_lora_weights_before_save(self, state_dict):
        return convert_to_original(state_dict)

    def convert_lora_weights_before_load(self, state_dict):
        return convert_to_diffusers(state_dict)