From ca7c5c950b11e15e4cc139879f9f5102b9839d4b Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 29 Jul 2025 05:31:54 -0600 Subject: [PATCH] Add support for Wan2.2 5B --- .../diffusion_models/__init__.py | 14 +- .../diffusion_models/wan22/__init__.py | 1 + .../diffusion_models/wan22/wan22_model.py | 259 +++++++ .../diffusion_models/wan22/wan22_pipeline.py | 263 +++++++ requirements.txt | 2 +- toolkit/models/wan21/autoencoder_kl_wan.py | 692 ++++++++++++++++-- toolkit/models/wan21/wan21.py | 17 +- toolkit/models/wan21/wan21_i2v.py | 8 +- toolkit/models/wan21/wan_utils.py | 54 +- ui/src/app/jobs/new/options.ts | 21 + version.py | 2 +- 11 files changed, 1241 insertions(+), 92 deletions(-) create mode 100644 extensions_built_in/diffusion_models/wan22/__init__.py create mode 100644 extensions_built_in/diffusion_models/wan22/wan22_model.py create mode 100644 extensions_built_in/diffusion_models/wan22/wan22_pipeline.py diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py index 54bd470c..b498e263 100644 --- a/extensions_built_in/diffusion_models/__init__.py +++ b/extensions_built_in/diffusion_models/__init__.py @@ -3,13 +3,15 @@ from .hidream import HidreamModel, HidreamE1Model from .f_light import FLiteModel from .omnigen2 import OmniGen2Model from .flux_kontext import FluxKontextModel +from .wan22 import Wan22Model AI_TOOLKIT_MODELS = [ # put a list of models here - ChromaModel, - HidreamModel, - HidreamE1Model, - FLiteModel, - OmniGen2Model, - FluxKontextModel + ChromaModel, + HidreamModel, + HidreamE1Model, + FLiteModel, + OmniGen2Model, + FluxKontextModel, + Wan22Model, ] diff --git a/extensions_built_in/diffusion_models/wan22/__init__.py b/extensions_built_in/diffusion_models/wan22/__init__.py new file mode 100644 index 00000000..765b1a18 --- /dev/null +++ b/extensions_built_in/diffusion_models/wan22/__init__.py @@ -0,0 +1 @@ +from .wan22_model import Wan22Model \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/wan22/wan22_model.py b/extensions_built_in/diffusion_models/wan22/wan22_model.py new file mode 100644 index 00000000..565a3fea --- /dev/null +++ b/extensions_built_in/diffusion_models/wan22/wan22_model.py @@ -0,0 +1,259 @@ +import torch +from toolkit.prompt_utils import PromptEmbeds +from PIL import Image +from diffusers import UniPCMultistepScheduler +import torch +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from .wan22_pipeline import Wan22Pipeline + +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from torchvision.transforms import functional as TF + +from toolkit.models.wan21.wan21 import Wan21, AggressiveWanUnloadPipeline +from toolkit.models.wan21.wan_utils import add_first_frame_conditioning_v22 + + +# for generation only? 
+scheduler_configUniPC = { + "_class_name": "UniPCMultistepScheduler", + "_diffusers_version": "0.35.0.dev0", + "beta_end": 0.02, + "beta_schedule": "linear", + "beta_start": 0.0001, + "disable_corrector": [], + "dynamic_thresholding_ratio": 0.995, + "final_sigmas_type": "zero", + "flow_shift": 5.0, + "lower_order_final": True, + "num_train_timesteps": 1000, + "predict_x0": True, + "prediction_type": "flow_prediction", + "rescale_betas_zero_snr": False, + "sample_max_value": 1.0, + "solver_order": 2, + "solver_p": None, + "solver_type": "bh2", + "steps_offset": 0, + "thresholding": False, + "time_shift_type": "exponential", + "timestep_spacing": "linspace", + "trained_betas": None, + "use_beta_sigmas": False, + "use_dynamic_shifting": False, + "use_exponential_sigmas": False, + "use_flow_sigmas": True, + "use_karras_sigmas": False, +} + +# for training. I think it is right +scheduler_config = { + "num_train_timesteps": 1000, + "shift": 5.0, + "use_dynamic_shifting": False, +} + + +class Wan22Model(Wan21): + arch = "wan22_5b" + _wan_generation_scheduler_config = scheduler_configUniPC + _wan_expand_timesteps = True + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device=device, + model_config=model_config, + dtype=dtype, + custom_pipeline=custom_pipeline, + noise_scheduler=noise_scheduler, + **kwargs, + ) + + self._wan_cache = None + + def get_bucket_divisibility(self): + # 16x compression and 2x2 patch size + return 32 + + def get_generation_pipeline(self): + scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config) + pipeline = Wan22Pipeline( + vae=self.vae, + transformer=self.model, + transformer_2=self.model, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=scheduler, + expand_timesteps=self._wan_expand_timesteps, + device=self.device_torch, + aggressive_offload=self.model_config.low_vram, + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + # static method to get the scheduler + @staticmethod + def get_train_scheduler(): + scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + return scheduler + + def get_base_model_version(self): + return "wan_2.2_5b" + + def generate_single_image( + self, + pipeline: AggressiveWanUnloadPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # reactivate progress bar since this is slooooow + pipeline.set_progress_bar_config(disable=False) + + num_frames = ( + (gen_config.num_frames - 1) // 4 + ) * 4 + 1 # make sure it is divisible by 4 + 1 + gen_config.num_frames = num_frames + + height = gen_config.height + width = gen_config.width + noise_mask = None + if gen_config.ctrl_img is not None: + control_img = Image.open(gen_config.ctrl_img).convert("RGB") + + d = self.get_bucket_divisibility() + + # make sure they are divisible by d + height = height // d * d + width = width // d * d + + # resize the control image + control_img = control_img.resize((width, height), Image.LANCZOS) + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = pipeline.prepare_latents( + 1, + num_channels_latents, + height, + width, + gen_config.num_frames, + torch.float32, + self.device_torch, + generator, + None, + ).to(self.torch_dtype) + + first_frame_n1p1 = ( + TF.to_tensor(control_img) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + * 2.0 + - 1.0 + ) # normalize to [-1, 1] + + gen_config.latents, noise_mask = add_first_frame_conditioning_v22( + latent_model_input=latents, first_frame=first_frame_n1p1, vae=self.vae + ) + + output = pipeline( + prompt_embeds=conditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype + ), + negative_prompt_embeds=unconditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype + ), + height=height, + width=width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + num_frames=gen_config.num_frames, + generator=generator, + return_dict=False, + output_type="pil", + noise_mask=noise_mask, + **extra, + )[0] + + # shape = [1, frames, channels, height, width] + batch_item = output[0] # list of pil images + if gen_config.num_frames > 1: + return batch_item # return the frames. + else: + # get just the first image + img = batch_item[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + batch: DataLoaderBatchDTO, + **kwargs, + ): + # videos come in (bs, num_frames, channels, height, width) + # images come in (bs, channels, height, width) + + # for wan, only do i2v for video for now. Images do normal t2i + conditioned_latent = latent_model_input + noise_mask = None + + with torch.no_grad(): + frames = batch.tensor + if len(frames.shape) == 4: + first_frames = frames + elif len(frames.shape) == 5: + first_frames = frames[:, 0] + # Add conditioning using the standalone function + conditioned_latent, noise_mask = add_first_frame_conditioning_v22( + latent_model_input=latent_model_input.to( + self.device_torch, self.torch_dtype + ), + first_frame=first_frames.to(self.device_torch, self.torch_dtype), + vae=self.vae, + ) + else: + raise ValueError(f"Unknown frame shape {frames.shape}") + + # make the noise mask + if noise_mask is None: + noise_mask = torch.ones( + conditioned_latent.shape, + dtype=conditioned_latent.dtype, + device=conditioned_latent.device, + ) + # todo write this better + t_chunks = torch.chunk(timestep, timestep.shape[0]) + out_t_chunks = [] + for t in t_chunks: + # seq_len: num_latent_frames * latent_height//2 * latent_width//2 + temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten() + # batch_size, seq_len + temp_ts = temp_ts.unsqueeze(0) + out_t_chunks.append(temp_ts) + timestep = torch.cat(out_t_chunks, dim=0) + + noise_pred = self.model( + hidden_states=conditioned_latent, + timestep=timestep, + encoder_hidden_states=text_embeddings.text_embeds, + return_dict=False, + **kwargs, + )[0] + return noise_pred diff --git a/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py b/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py new file mode 100644 index 00000000..442b5c7a --- /dev/null +++ b/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py @@ -0,0 +1,263 @@ + +import torch +from toolkit.basic import flush +from transformers import AutoTokenizer, UMT5EncoderModel +from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan 
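# Illustrative sketch (not part of the patch; shapes are assumptions) of the expanded
# per-token timestep that get_noise_prediction above builds for Wan2.2 TI2V and that
# Wan22Pipeline below reproduces at inference: latent tokens overwritten with the
# encoded first frame keep timestep 0, every other token keeps the scalar t. The ::2
# stride mirrors the transformer's 2x2 spatial patchification, so the flattened
# length is num_latent_frames * (latent_height // 2) * (latent_width // 2).
import torch

bs, z_dim, t_frames, h, w = 1, 48, 31, 48, 48   # hypothetical 768x768, 121-frame latent
noise_mask = torch.ones(bs, z_dim, t_frames, h, w)
noise_mask[:, :, :1] = 0.0                       # first latent frame holds the encoded image
t = torch.tensor(681.0)                          # current denoising timestep

temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()   # (seq_len,)
timestep = temp_ts.unsqueeze(0).expand(bs, -1)            # (batch_size, seq_len)
assert timestep.shape == (bs, t_frames * (h // 2) * (w // 2))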
+import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from typing import List +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from typing import Any, Callable, Dict, List, Optional, Union + + + +class Wan22Pipeline(WanPipeline): + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: WanTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + transformer_2: Optional[WanTransformer3DModel] = None, + boundary_ratio: Optional[float] = None, + expand_timesteps: bool = False, # Wan2.2 ti2v + device: torch.device = torch.device("cuda"), + aggressive_offload: bool = False, + ): + super().__init__( + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + transformer_2=transformer_2, + boundary_ratio=boundary_ratio, + expand_timesteps=expand_timesteps, + vae=vae, + scheduler=scheduler, + ) + self._aggressive_offload = aggressive_offload + self._exec_device = device + @property + def _execution_device(self): + return self._exec_device + + def __call__( + self: WanPipeline, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], + PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + noise_mask: Optional[torch.Tensor] = None, + ): + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # unload vae and transformer + vae_device = self.vae.device + transformer_device = self.transformer.device + text_encoder_device = self.text_encoder.device + device = self.transformer.device + + if self._aggressive_offload: + print("Unloading vae") + self.vae.to("cpu") + self.text_encoder.to(device) + flush() + + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if self._aggressive_offload: + # unload text encoder + print("Unloading text encoder") + self.text_encoder.to("cpu") + self.transformer.to(device) + flush() + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(device, transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to( + device, transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + mask = noise_mask + if mask is None: + mask = torch.ones(latents.shape, dtype=torch.float32, device=device) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - \ + num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(device, transformer_dtype) + if self.config.expand_timesteps: + # seq_len: num_latent_frames * latent_height//2 * latent_width//2 + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + # batch_size, seq_len + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * \ + (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False)[0] + + # apply i2v mask + latents = (latent_model_input * (1 - mask)) + ( + latents * mask + ) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if self._aggressive_offload: + # unload transformer + print("Unloading transformer") + self.transformer.to("cpu") + if self.transformer_2 is 
not None: + self.transformer_2.to("cpu") + # load vae + print("Loading Vae") + self.vae.to(vae_device) + flush() + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video( + video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/requirements.txt b/requirements.txt index 65cc41c3..15a086ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torchao==0.10.0 safetensors git+https://github.com/jaretburkett/easy_dwpose.git -git+https://github.com/huggingface/diffusers@00f95b9755718aabb65456e791b8408526ae6e76 +git+https://github.com/huggingface/diffusers@56d438727036b0918b30bbe3110c5fe1634ed19d transformers==4.52.4 lycoris-lora==1.8.3 flatten_json diff --git a/toolkit/models/wan21/autoencoder_kl_wan.py b/toolkit/models/wan21/autoencoder_kl_wan.py index 4f5b6ebd..3a966108 100644 --- a/toolkit/models/wan21/autoencoder_kl_wan.py +++ b/toolkit/models/wan21/autoencoder_kl_wan.py @@ -20,6 +20,7 @@ import torch.nn.functional as F import torch.utils.checkpoint from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin from diffusers.utils import logging from diffusers.utils.accelerate_utils import apply_forward_hook from diffusers.models.activations import get_activation @@ -34,6 +35,104 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name CACHE_T = 2 +class AvgDown3D(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + 
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1:, :, :] + return x + class WanCausalConv3d(nn.Conv3d): r""" A custom 3D causal convolution layer with feature caching support. @@ -134,19 +233,23 @@ class WanResample(nn.Module): - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. """ - def __init__(self, dim: int, mode: str) -> None: + def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None: super().__init__() self.dim = dim self.mode = mode + # default to dim //2 + if upsample_out_dim is None: + upsample_out_dim = dim // 2 + # layers if mode == "upsample2d": self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1) ) elif mode == "upsample3d": self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1) ) self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) @@ -363,6 +466,48 @@ class WanMidBlock(nn.Module): return x +class WanResidualDownBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + num_res_blocks, + temperal_downsample=False, + down_flag=False): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + resnets = [] + for _ in range(num_res_blocks): + resnets.append(WanResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + self.resnets = nn.ModuleList(resnets) + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + self.downsampler = WanResample(out_dim, mode=mode) + else: + self.downsampler = None + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for resnet in self.resnets: + x = resnet(x, feat_cache, feat_idx) + if self.downsampler is not None: + x = self.downsampler(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + class WanEncoder3d(nn.Module): r""" A 3D encoder module. 
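# Shape sketch (illustrative, hypothetical dimensions) of the two parameter-free
# shortcut modules introduced above for the Wan2.2 VAE: AvgDown3D folds a
# (factor_t, factor_s, factor_s) neighbourhood into channels and averages channel
# groups down to out_channels, while DupUp3D repeats channels and unfolds them back
# into time and space. WanResidualDownBlock (and WanResidualUpBlock below) use them
# as residual shortcuts around the learned resnet + resample path.
import torch
from toolkit.models.wan21.autoencoder_kl_wan import AvgDown3D, DupUp3D

down = AvgDown3D(in_channels=96, out_channels=192, factor_t=2, factor_s=2)
up = DupUp3D(in_channels=192, out_channels=96, factor_t=2, factor_s=2)

x = torch.randn(1, 96, 8, 64, 64)   # (B, C, T, H, W)
y = down(x)                         # -> (1, 192, 4, 32, 32)
z = up(y)                           # -> (1, 96, 8, 64, 64), shape-inverse of AvgDown3D
print(y.shape, z.shape)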
@@ -380,6 +525,7 @@ class WanEncoder3d(nn.Module): def __init__( self, + in_channels: int = 3, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], @@ -388,6 +534,7 @@ class WanEncoder3d(nn.Module): temperal_downsample=[True, True, False], dropout=0.0, non_linearity: str = "silu", + is_residual: bool = False, # wan 2.2 vae use a residual downblock ): super().__init__() self.dim = dim @@ -403,23 +550,35 @@ class WanEncoder3d(nn.Module): scale = 1.0 # init block - self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1) + self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1) # downsample blocks self.down_blocks = nn.ModuleList([]) for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): # residual (+attention) blocks - for _ in range(num_res_blocks): - self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout)) - if scale in attn_scales: - self.down_blocks.append(WanAttentionBlock(out_dim)) - in_dim = out_dim + if is_residual: + self.down_blocks.append( + WanResidualDownBlock( + in_dim, + out_dim, + dropout, + num_res_blocks, + temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False, + down_flag=i != len(dim_mult) - 1, + ) + ) + else: + for _ in range(num_res_blocks): + self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(WanAttentionBlock(out_dim)) + in_dim = out_dim - # downsample block - if i != len(dim_mult) - 1: - mode = "downsample3d" if temperal_downsample[i] else "downsample2d" - self.down_blocks.append(WanResample(out_dim, mode=mode)) - scale /= 2.0 + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(WanResample(out_dim, mode=mode)) + scale /= 2.0 # middle blocks self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1) @@ -469,6 +628,92 @@ class WanEncoder3d(nn.Module): x = self.conv_out(x) return x +class WanResidualUpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + temperal_upsample (bool): Whether to upsample on temporal dimension + up_flag (bool): Whether to upsample or not + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + temperal_upsample: bool = False, + up_flag: bool = False, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2, + ) + else: + self.avg_shortcut = None + + # create residual blocks + resnets = [] + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + if up_flag: + upsample_mode = "upsample3d" if temperal_upsample else "upsample2d" + self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim) + else: + self.upsampler = None + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + """ + Forward pass through the upsampling block. 
+ + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + x_copy = x.clone() + + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsampler is not None: + if feat_cache is not None: + x = self.upsampler(x, feat_cache, feat_idx) + else: + x = self.upsampler(x) + + if self.avg_shortcut is not None: + x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk) + + return x class WanUpBlock(nn.Module): """ @@ -513,7 +758,7 @@ class WanUpBlock(nn.Module): self.gradient_checkpointing = False - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None): """ Forward pass through the upsampling block. @@ -564,6 +809,8 @@ class WanDecoder3d(nn.Module): temperal_upsample=[False, True, True], dropout=0.0, non_linearity: str = "silu", + out_channels: int = 3, + is_residual: bool = False, ): super().__init__() self.dim = dim @@ -577,7 +824,6 @@ class WanDecoder3d(nn.Module): # dimensions dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] - scale = 1.0 / 2 ** (len(dim_mult) - 2) # init block self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1) @@ -589,36 +835,47 @@ class WanDecoder3d(nn.Module): self.up_blocks = nn.ModuleList([]) for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): # residual (+attention) blocks - if i > 0: + if i > 0 and not is_residual: + # wan vae 2.1 in_dim = in_dim // 2 - # Determine if we need upsampling + # determine if we need upsampling + up_flag = i != len(dim_mult) - 1 + # determine upsampling mode, if not upsampling, set to None upsample_mode = None - if i != len(dim_mult) - 1: - upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" - + if up_flag and temperal_upsample[i]: + upsample_mode = "upsample3d" + elif up_flag: + upsample_mode = "upsample2d" # Create and add the upsampling block - up_block = WanUpBlock( - in_dim=in_dim, - out_dim=out_dim, - num_res_blocks=num_res_blocks, - dropout=dropout, - upsample_mode=upsample_mode, - non_linearity=non_linearity, - ) + if is_residual: + up_block = WanResidualUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + temperal_upsample=temperal_upsample[i] if up_flag else False, + up_flag= up_flag, + non_linearity=non_linearity, + ) + else: + up_block = WanUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) self.up_blocks.append(up_block) - # Update scale for next iteration - if upsample_mode is not None: - scale *= 2.0 - # output blocks self.norm_out = WanRMS_norm(out_dim, images=False) - self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1) + self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1) self.gradient_checkpointing = False - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): ## conv1 if feat_cache is not None: idx = feat_idx[0] @@ -633,20 +890,11 @@ class WanDecoder3d(nn.Module): x = self.conv_in(x) ## middle - if torch.is_grad_enabled() and self.gradient_checkpointing: - # middle - x = self._gradient_checkpointing_func(self.mid_block, x, feat_cache, feat_idx) - - ## upsamples - for up_block in self.up_blocks: - x = 
self._gradient_checkpointing_func(up_block, x, feat_cache, feat_idx) - - else: - x = self.mid_block(x, feat_cache, feat_idx) + x = self.mid_block(x, feat_cache, feat_idx) - ## upsamples - for up_block in self.up_blocks: - x = up_block(x, feat_cache, feat_idx) + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx, first_chunk = first_chunk) ## head x = self.norm_out(x) @@ -665,7 +913,46 @@ class WanDecoder3d(nn.Module): return x -class AutoencoderKLWan(ModelMixin, ConfigMixin): +def patchify(x, patch_size): + # YiYi TODO: refactor this + from einops import rearrange + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size): + # YiYi TODO: refactor this + from einops import rearrange + if patch_size == 1: + return x + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size, + ) + return x + +class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Introduced in [Wan 2.1]. @@ -674,12 +961,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin): for all models (such as downloading or saving). """ - _supports_gradient_checkpointing = True + _supports_gradient_checkpointing = False @register_to_config def __init__( self, base_dim: int = 96, + decoder_base_dim: Optional[int] = None, z_dim: int = 16, dim_mult: Tuple[int] = [1, 2, 4, 4], num_res_blocks: int = 2, @@ -722,6 +1010,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin): 2.8251, 1.9160, ], + is_residual: bool = False, + in_channels: int = 3, + out_channels: int = 3, + patch_size: Optional[int] = None, + scale_factor_temporal: Optional[int] = 4, + scale_factor_spatial: Optional[int] = 8, + clip_output: bool = True, ) -> None: super().__init__() @@ -729,37 +1024,119 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin): self.temperal_downsample = temperal_downsample self.temperal_upsample = temperal_downsample[::-1] + if decoder_base_dim is None: + decoder_base_dim = base_dim + self.encoder = WanEncoder3d( - base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + in_channels=in_channels, dim=base_dim, z_dim=z_dim * 2, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, dropout=dropout, is_residual=is_residual ) self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1) self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1) self.decoder = WanDecoder3d( - base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout + dim=decoder_base_dim, z_dim=z_dim, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_upsample=self.temperal_upsample, dropout=dropout, out_channels=out_channels, is_residual=is_residual ) - def clear_cache(self): - def _count_conv3d(model): - count = 0 - for m in model.modules(): - if isinstance(m, WanCausalConv3d): - count += 1 - return count + self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) - self._conv_num = 
_count_conv3d(self.decoder) + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules()) + if self.decoder is not None + else 0, + "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules()) + if self.encoder is not None + else 0, + } + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. 
+ """ + self.use_slicing = False + + def clear_cache(self): + # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call + self._conv_num = self._cached_conv_counts["decoder"] self._conv_idx = [0] self._feat_map = [None] * self._conv_num # cache encode - self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_num = self._cached_conv_counts["encoder"] self._enc_conv_idx = [0] self._enc_feat_map = [None] * self._enc_conv_num - def _encode(self, x: torch.Tensor) -> torch.Tensor: + def _encode(self, x: torch.Tensor): + _, _, num_frame, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + self.clear_cache() - ## cache - t = x.shape[2] - iter_ = 1 + (t - 1) // 4 + if self.config.patch_size is not None: + x = patchify(x, patch_size=self.config.patch_size) + iter_ = 1 + (num_frame - 1) // 4 for i in range(iter_): self._enc_conv_idx = [0] if i == 0: @@ -773,8 +1150,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin): out = torch.cat([out, out_], 2) enc = self.quant_conv(out) - mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :] - enc = torch.cat([mu, logvar], dim=1) self.clear_cache() return enc @@ -794,27 +1169,39 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin): The latent representations of the encoded videos. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ - h = self._encode(x) + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) posterior = DiagonalGaussianDistribution(h) + if not return_dict: return (posterior,) return AutoencoderKLOutput(latent_dist=posterior) - def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - self.clear_cache() + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio - iter_ = z.shape[2] + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() x = self.post_quant_conv(z) - for i in range(iter_): - + for i in range(num_frame): self._conv_idx = [0] if i == 0: - out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True) else: out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) out = torch.cat([out, out_], 2) - out = torch.clamp(out, min=-1.0, max=1.0) + if self.config.clip_output: + out = torch.clamp(out, min=-1.0, max=1.0) + if self.config.patch_size is not None: + out = unpatchify(out, patch_size=self.config.patch_size) self.clear_cache() if not return_dict: return (out,) @@ -836,12 +1223,161 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin): If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is returned. 
""" - decoded = self._decode(z).sample + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + if not return_dict: return (decoded,) - return DecoderOutput(sample=decoded) + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. 
+ + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + def forward( self, sample: torch.Tensor, @@ -862,4 +1398,4 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin): else: z = posterior.mode() dec = self.decode(z, return_dict=return_dict) - return dec + return dec \ No newline at end of file diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py index 9f029c7b..7146630d 100644 --- a/toolkit/models/wan21/wan21.py +++ b/toolkit/models/wan21/wan21.py @@ -89,12 +89,18 @@ class AggressiveWanUnloadPipeline(WanPipeline): transformer: WanTransformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchEulerDiscreteScheduler, + transformer_2: Optional[WanTransformer3DModel] = None, + boundary_ratio: Optional[float] = None, + expand_timesteps: bool = False, # Wan2.2 ti2v device: torch.device = torch.device("cuda"), ): super().__init__( tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, + transformer_2=transformer_2, + boundary_ratio=boundary_ratio, + expand_timesteps=expand_timesteps, vae=vae, scheduler=scheduler, ) @@ -300,6 +306,8 @@ class AggressiveWanUnloadPipeline(WanPipeline): class Wan21(BaseModel): arch = 'wan21' + _wan_generation_scheduler_config = scheduler_configUniPC + 
_wan_expand_timesteps = False def __init__( self, device, @@ -331,7 +339,7 @@ class Wan21(BaseModel): dtype = self.torch_dtype model_path = self.model_config.name_or_path - self.print_and_status_update("Loading Wan2.1 model") + self.print_and_status_update("Loading Wan model") subfolder = 'transformer' transformer_path = model_path if os.path.exists(transformer_path): @@ -380,7 +388,6 @@ class Wan21(BaseModel): # patch the state dict method patch_dequantization_on_save(transformer) quantization_type = get_qtype(self.model_config.qtype) - self.print_and_status_update("Quantizing transformer") if self.model_config.low_vram: print("Quantizing blocks") orig_exclude = copy.deepcopy(quantization_args['exclude']) @@ -474,22 +481,26 @@ class Wan21(BaseModel): self.tokenizer = tokenizer def get_generation_pipeline(self): - scheduler = UniPCMultistepScheduler(**scheduler_configUniPC) + scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config) if self.model_config.low_vram: pipeline = AggressiveWanUnloadPipeline( vae=self.vae, transformer=self.model, + transformer_2=self.model, text_encoder=self.text_encoder, tokenizer=self.tokenizer, scheduler=scheduler, + expand_timesteps=self._wan_expand_timesteps, device=self.device_torch ) else: pipeline = WanPipeline( vae=self.vae, transformer=self.unet, + transformer_2=self.unet, text_encoder=self.text_encoder, tokenizer=self.tokenizer, + expand_timesteps=self._wan_expand_timesteps, scheduler=scheduler, ) diff --git a/toolkit/models/wan21/wan21_i2v.py b/toolkit/models/wan21/wan21_i2v.py index 8b9f2918..bf5a88b8 100644 --- a/toolkit/models/wan21/wan21_i2v.py +++ b/toolkit/models/wan21/wan21_i2v.py @@ -48,11 +48,13 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline): self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - image_encoder: CLIPVisionModel, - image_processor: CLIPImageProcessor, transformer: WanTransformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchEulerDiscreteScheduler, + image_processor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModel = None, + transformer_2: WanTransformer3DModel = None, + boundary_ratio: Optional[float] = None, device: torch.device = torch.device("cuda"), ): super().__init__( @@ -63,6 +65,8 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline): transformer=transformer, scheduler=scheduler, image_processor=image_processor, + transformer_2=transformer_2, + boundary_ratio=boundary_ratio, ) self._exec_device = device diff --git a/toolkit/models/wan21/wan_utils.py b/toolkit/models/wan21/wan_utils.py index 422cf3f7..3a300d2b 100644 --- a/toolkit/models/wan21/wan_utils.py +++ b/toolkit/models/wan21/wan_utils.py @@ -39,7 +39,7 @@ def add_first_frame_conditioning( first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1) # resize first frame to match the latent model input - vae_scale_factor = 8 + vae_scale_factor = vae.config.scale_factor_spatial first_frame = F.interpolate( first_frame, size=(latent_model_input.shape[3] * vae_scale_factor, latent_model_input.shape[4] * vae_scale_factor), @@ -111,3 +111,55 @@ def add_first_frame_conditioning( [latent_model_input, first_frame_condition], dim=1) return conditioned_latent + + +def add_first_frame_conditioning_v22( + latent_model_input, + first_frame, + vae +): + """ + Overwrites first few time steps in latent_model_input with VAE-encoded first_frame, + and returns the modified latent + binary mask (0=conditioned, 1=noise). 
+ + Args: + latent_model_input: torch.Tensor of shape (bs, 48, T, H, W) + first_frame: torch.Tensor of shape (bs, 3, H*scale, W*scale) + vae: VAE model with .encode() and .config.latents_mean/std + + Returns: + latent: (bs, 48, T, H, W) - modified input latent + mask: (bs, 1, T, H, W) - binary mask + """ + device = latent_model_input.device + dtype = latent_model_input.dtype + bs, _, T, H, W = latent_model_input.shape + scale = vae.config.scale_factor_spatial + target_h = H * scale + target_w = W * scale + + # Ensure shape + if first_frame.ndim == 3: + first_frame = first_frame.unsqueeze(0) + if first_frame.shape[0] != bs: + first_frame = first_frame.expand(bs, -1, -1, -1) + + # Resize and encode + first_frame_up = F.interpolate(first_frame, size=(target_h, target_w), mode="bilinear", align_corners=False) + first_frame_up = first_frame_up.unsqueeze(2) # (bs, 3, 1, H, W) + encoded = vae.encode(first_frame_up).latent_dist.sample().to(dtype).to(device) + + # Normalize + mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype) + std = 1.0 / torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype) + encoded = (encoded - mean) * std + + # Replace in latent + latent = latent_model_input.clone() + latent[:, :, :encoded.shape[2]] = encoded # typically first frame: [:, :, 0] + + # Mask: 0 where conditioned, 1 otherwise + mask = torch.ones(bs, 1, T, H, W, device=device, dtype=dtype) + mask[:, :, :encoded.shape[2]] = 0.0 + + return latent, mask \ No newline at end of file diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index 17101cbb..6fcfc8f6 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -181,6 +181,27 @@ export const modelArchs: ModelArch[] = [ disableSections: ['network.conv'], additionalSections: ['datasets.num_frames', 'model.low_vram'], }, + { + name: 'wan22_5b', + label: 'Wan 2.2 TI2V (5B)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.2-TI2V-5B-Diffusers', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [121, 1], + 'config.process[0].sample.fps': [24, 1], + 'config.process[0].sample.width': [768, 1024], + 'config.process[0].sample.height': [768, 1024], + }, + disableSections: ['network.conv'], + additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'], + }, { name: 'lumina2', label: 'Lumina2', diff --git a/version.py b/version.py index b654c117..2abb236f 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.3.13" \ No newline at end of file +VERSION = "0.3.14" \ No newline at end of file
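# Usage sketch (illustrative; the checkpoint path comes from options.ts above, and the
# shapes and values here are assumptions, not part of the patch) showing how
# add_first_frame_conditioning_v22 pairs with the masked update that Wan22Pipeline
# applies after every scheduler step: conditioned tokens are re-imposed from the
# pre-step latent so they never drift, while only masked tokens are denoised.
import torch
from toolkit.models.wan21.autoencoder_kl_wan import AutoencoderKLWan
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning_v22

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.2-TI2V-5B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)

latents = torch.randn(1, 48, 31, 48, 48)           # (B, z_dim, T, H, W) for a 768x768, 121-frame sample
first_frame = torch.rand(1, 3, 768, 768) * 2 - 1   # control image, normalized to [-1, 1]

conditioned, noise_mask = add_first_frame_conditioning_v22(
    latent_model_input=latents, first_frame=first_frame, vae=vae
)

# one masked update step, mirroring the pipeline's
# latents = latent_model_input * (1 - mask) + latents * mask
stepped = conditioned + 0.1 * torch.randn_like(conditioned)   # stand-in for scheduler.step(...)
latents = conditioned * (1 - noise_mask) + stepped * noise_mask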