mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Add support for Wan2.2 5B
This commit is contained in:
@@ -3,13 +3,15 @@ from .hidream import HidreamModel, HidreamE1Model
|
|||||||
from .f_light import FLiteModel
|
from .f_light import FLiteModel
|
||||||
from .omnigen2 import OmniGen2Model
|
from .omnigen2 import OmniGen2Model
|
||||||
from .flux_kontext import FluxKontextModel
|
from .flux_kontext import FluxKontextModel
|
||||||
|
from .wan22 import Wan22Model
|
||||||
|
|
||||||
AI_TOOLKIT_MODELS = [
|
AI_TOOLKIT_MODELS = [
|
||||||
# put a list of models here
|
# put a list of models here
|
||||||
ChromaModel,
|
ChromaModel,
|
||||||
HidreamModel,
|
HidreamModel,
|
||||||
HidreamE1Model,
|
HidreamE1Model,
|
||||||
FLiteModel,
|
FLiteModel,
|
||||||
OmniGen2Model,
|
OmniGen2Model,
|
||||||
FluxKontextModel
|
FluxKontextModel,
|
||||||
|
Wan22Model,
|
||||||
]
|
]
|
||||||
|
|||||||
1
extensions_built_in/diffusion_models/wan22/__init__.py
Normal file
1
extensions_built_in/diffusion_models/wan22/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .wan22_model import Wan22Model
|
||||||
259
extensions_built_in/diffusion_models/wan22/wan22_model.py
Normal file
259
extensions_built_in/diffusion_models/wan22/wan22_model.py
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
import torch
|
||||||
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
|
from PIL import Image
|
||||||
|
from diffusers import UniPCMultistepScheduler
|
||||||
|
import torch
|
||||||
|
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||||
|
from toolkit.samplers.custom_flowmatch_sampler import (
|
||||||
|
CustomFlowMatchEulerDiscreteScheduler,
|
||||||
|
)
|
||||||
|
from .wan22_pipeline import Wan22Pipeline
|
||||||
|
|
||||||
|
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||||
|
from torchvision.transforms import functional as TF
|
||||||
|
|
||||||
|
from toolkit.models.wan21.wan21 import Wan21, AggressiveWanUnloadPipeline
|
||||||
|
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning_v22
|
||||||
|
|
||||||
|
|
||||||
|
# for generation only?
|
||||||
|
scheduler_configUniPC = {
|
||||||
|
"_class_name": "UniPCMultistepScheduler",
|
||||||
|
"_diffusers_version": "0.35.0.dev0",
|
||||||
|
"beta_end": 0.02,
|
||||||
|
"beta_schedule": "linear",
|
||||||
|
"beta_start": 0.0001,
|
||||||
|
"disable_corrector": [],
|
||||||
|
"dynamic_thresholding_ratio": 0.995,
|
||||||
|
"final_sigmas_type": "zero",
|
||||||
|
"flow_shift": 5.0,
|
||||||
|
"lower_order_final": True,
|
||||||
|
"num_train_timesteps": 1000,
|
||||||
|
"predict_x0": True,
|
||||||
|
"prediction_type": "flow_prediction",
|
||||||
|
"rescale_betas_zero_snr": False,
|
||||||
|
"sample_max_value": 1.0,
|
||||||
|
"solver_order": 2,
|
||||||
|
"solver_p": None,
|
||||||
|
"solver_type": "bh2",
|
||||||
|
"steps_offset": 0,
|
||||||
|
"thresholding": False,
|
||||||
|
"time_shift_type": "exponential",
|
||||||
|
"timestep_spacing": "linspace",
|
||||||
|
"trained_betas": None,
|
||||||
|
"use_beta_sigmas": False,
|
||||||
|
"use_dynamic_shifting": False,
|
||||||
|
"use_exponential_sigmas": False,
|
||||||
|
"use_flow_sigmas": True,
|
||||||
|
"use_karras_sigmas": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# for training. I think it is right
|
||||||
|
scheduler_config = {
|
||||||
|
"num_train_timesteps": 1000,
|
||||||
|
"shift": 5.0,
|
||||||
|
"use_dynamic_shifting": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Wan22Model(Wan21):
|
||||||
|
arch = "wan22_5b"
|
||||||
|
_wan_generation_scheduler_config = scheduler_configUniPC
|
||||||
|
_wan_expand_timesteps = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
dtype="bf16",
|
||||||
|
custom_pipeline=None,
|
||||||
|
noise_scheduler=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
device=device,
|
||||||
|
model_config=model_config,
|
||||||
|
dtype=dtype,
|
||||||
|
custom_pipeline=custom_pipeline,
|
||||||
|
noise_scheduler=noise_scheduler,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._wan_cache = None
|
||||||
|
|
||||||
|
def get_bucket_divisibility(self):
|
||||||
|
# 16x compression and 2x2 patch size
|
||||||
|
return 32
|
||||||
|
|
||||||
|
def get_generation_pipeline(self):
|
||||||
|
scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
|
||||||
|
pipeline = Wan22Pipeline(
|
||||||
|
vae=self.vae,
|
||||||
|
transformer=self.model,
|
||||||
|
transformer_2=self.model,
|
||||||
|
text_encoder=self.text_encoder,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
expand_timesteps=self._wan_expand_timesteps,
|
||||||
|
device=self.device_torch,
|
||||||
|
aggressive_offload=self.model_config.low_vram,
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline = pipeline.to(self.device_torch)
|
||||||
|
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
# static method to get the scheduler
|
||||||
|
@staticmethod
|
||||||
|
def get_train_scheduler():
|
||||||
|
scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
|
||||||
|
return scheduler
|
||||||
|
|
||||||
|
def get_base_model_version(self):
|
||||||
|
return "wan_2.2_5b"
|
||||||
|
|
||||||
|
def generate_single_image(
|
||||||
|
self,
|
||||||
|
pipeline: AggressiveWanUnloadPipeline,
|
||||||
|
gen_config: GenerateImageConfig,
|
||||||
|
conditional_embeds: PromptEmbeds,
|
||||||
|
unconditional_embeds: PromptEmbeds,
|
||||||
|
generator: torch.Generator,
|
||||||
|
extra: dict,
|
||||||
|
):
|
||||||
|
# reactivate progress bar since this is slooooow
|
||||||
|
pipeline.set_progress_bar_config(disable=False)
|
||||||
|
|
||||||
|
num_frames = (
|
||||||
|
(gen_config.num_frames - 1) // 4
|
||||||
|
) * 4 + 1 # make sure it is divisible by 4 + 1
|
||||||
|
gen_config.num_frames = num_frames
|
||||||
|
|
||||||
|
height = gen_config.height
|
||||||
|
width = gen_config.width
|
||||||
|
noise_mask = None
|
||||||
|
if gen_config.ctrl_img is not None:
|
||||||
|
control_img = Image.open(gen_config.ctrl_img).convert("RGB")
|
||||||
|
|
||||||
|
d = self.get_bucket_divisibility()
|
||||||
|
|
||||||
|
# make sure they are divisible by d
|
||||||
|
height = height // d * d
|
||||||
|
width = width // d * d
|
||||||
|
|
||||||
|
# resize the control image
|
||||||
|
control_img = control_img.resize((width, height), Image.LANCZOS)
|
||||||
|
|
||||||
|
# 5. Prepare latent variables
|
||||||
|
num_channels_latents = self.transformer.config.in_channels
|
||||||
|
latents = pipeline.prepare_latents(
|
||||||
|
1,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
gen_config.num_frames,
|
||||||
|
torch.float32,
|
||||||
|
self.device_torch,
|
||||||
|
generator,
|
||||||
|
None,
|
||||||
|
).to(self.torch_dtype)
|
||||||
|
|
||||||
|
first_frame_n1p1 = (
|
||||||
|
TF.to_tensor(control_img)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.to(self.device_torch, dtype=self.torch_dtype)
|
||||||
|
* 2.0
|
||||||
|
- 1.0
|
||||||
|
) # normalize to [-1, 1]
|
||||||
|
|
||||||
|
gen_config.latents, noise_mask = add_first_frame_conditioning_v22(
|
||||||
|
latent_model_input=latents, first_frame=first_frame_n1p1, vae=self.vae
|
||||||
|
)
|
||||||
|
|
||||||
|
output = pipeline(
|
||||||
|
prompt_embeds=conditional_embeds.text_embeds.to(
|
||||||
|
self.device_torch, dtype=self.torch_dtype
|
||||||
|
),
|
||||||
|
negative_prompt_embeds=unconditional_embeds.text_embeds.to(
|
||||||
|
self.device_torch, dtype=self.torch_dtype
|
||||||
|
),
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_inference_steps=gen_config.num_inference_steps,
|
||||||
|
guidance_scale=gen_config.guidance_scale,
|
||||||
|
latents=gen_config.latents,
|
||||||
|
num_frames=gen_config.num_frames,
|
||||||
|
generator=generator,
|
||||||
|
return_dict=False,
|
||||||
|
output_type="pil",
|
||||||
|
noise_mask=noise_mask,
|
||||||
|
**extra,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# shape = [1, frames, channels, height, width]
|
||||||
|
batch_item = output[0] # list of pil images
|
||||||
|
if gen_config.num_frames > 1:
|
||||||
|
return batch_item # return the frames.
|
||||||
|
else:
|
||||||
|
# get just the first image
|
||||||
|
img = batch_item[0]
|
||||||
|
return img
|
||||||
|
|
||||||
|
def get_noise_prediction(
|
||||||
|
self,
|
||||||
|
latent_model_input: torch.Tensor,
|
||||||
|
timestep: torch.Tensor, # 0 to 1000 scale
|
||||||
|
text_embeddings: PromptEmbeds,
|
||||||
|
batch: DataLoaderBatchDTO,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# videos come in (bs, num_frames, channels, height, width)
|
||||||
|
# images come in (bs, channels, height, width)
|
||||||
|
|
||||||
|
# for wan, only do i2v for video for now. Images do normal t2i
|
||||||
|
conditioned_latent = latent_model_input
|
||||||
|
noise_mask = None
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
frames = batch.tensor
|
||||||
|
if len(frames.shape) == 4:
|
||||||
|
first_frames = frames
|
||||||
|
elif len(frames.shape) == 5:
|
||||||
|
first_frames = frames[:, 0]
|
||||||
|
# Add conditioning using the standalone function
|
||||||
|
conditioned_latent, noise_mask = add_first_frame_conditioning_v22(
|
||||||
|
latent_model_input=latent_model_input.to(
|
||||||
|
self.device_torch, self.torch_dtype
|
||||||
|
),
|
||||||
|
first_frame=first_frames.to(self.device_torch, self.torch_dtype),
|
||||||
|
vae=self.vae,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown frame shape {frames.shape}")
|
||||||
|
|
||||||
|
# make the noise mask
|
||||||
|
if noise_mask is None:
|
||||||
|
noise_mask = torch.ones(
|
||||||
|
conditioned_latent.shape,
|
||||||
|
dtype=conditioned_latent.dtype,
|
||||||
|
device=conditioned_latent.device,
|
||||||
|
)
|
||||||
|
# todo write this better
|
||||||
|
t_chunks = torch.chunk(timestep, timestep.shape[0])
|
||||||
|
out_t_chunks = []
|
||||||
|
for t in t_chunks:
|
||||||
|
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
|
||||||
|
temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()
|
||||||
|
# batch_size, seq_len
|
||||||
|
temp_ts = temp_ts.unsqueeze(0)
|
||||||
|
out_t_chunks.append(temp_ts)
|
||||||
|
timestep = torch.cat(out_t_chunks, dim=0)
|
||||||
|
|
||||||
|
noise_pred = self.model(
|
||||||
|
hidden_states=conditioned_latent,
|
||||||
|
timestep=timestep,
|
||||||
|
encoder_hidden_states=text_embeddings.text_embeds,
|
||||||
|
return_dict=False,
|
||||||
|
**kwargs,
|
||||||
|
)[0]
|
||||||
|
return noise_pred
|
||||||
263
extensions_built_in/diffusion_models/wan22/wan22_pipeline.py
Normal file
263
extensions_built_in/diffusion_models/wan22/wan22_pipeline.py
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
|
||||||
|
import torch
|
||||||
|
from toolkit.basic import flush
|
||||||
|
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||||
|
from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan
|
||||||
|
import torch
|
||||||
|
from diffusers import FlowMatchEulerDiscreteScheduler
|
||||||
|
from typing import List
|
||||||
|
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
|
||||||
|
from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
|
||||||
|
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Wan22Pipeline(WanPipeline):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer: AutoTokenizer,
|
||||||
|
text_encoder: UMT5EncoderModel,
|
||||||
|
transformer: WanTransformer3DModel,
|
||||||
|
vae: AutoencoderKLWan,
|
||||||
|
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||||
|
transformer_2: Optional[WanTransformer3DModel] = None,
|
||||||
|
boundary_ratio: Optional[float] = None,
|
||||||
|
expand_timesteps: bool = False, # Wan2.2 ti2v
|
||||||
|
device: torch.device = torch.device("cuda"),
|
||||||
|
aggressive_offload: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
transformer=transformer,
|
||||||
|
transformer_2=transformer_2,
|
||||||
|
boundary_ratio=boundary_ratio,
|
||||||
|
expand_timesteps=expand_timesteps,
|
||||||
|
vae=vae,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
self._aggressive_offload = aggressive_offload
|
||||||
|
self._exec_device = device
|
||||||
|
@property
|
||||||
|
def _execution_device(self):
|
||||||
|
return self._exec_device
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self: WanPipeline,
|
||||||
|
prompt: Union[str, List[str]] = None,
|
||||||
|
negative_prompt: Union[str, List[str]] = None,
|
||||||
|
height: int = 480,
|
||||||
|
width: int = 832,
|
||||||
|
num_frames: int = 81,
|
||||||
|
num_inference_steps: int = 50,
|
||||||
|
guidance_scale: float = 5.0,
|
||||||
|
num_videos_per_prompt: Optional[int] = 1,
|
||||||
|
generator: Optional[Union[torch.Generator,
|
||||||
|
List[torch.Generator]]] = None,
|
||||||
|
latents: Optional[torch.Tensor] = None,
|
||||||
|
prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
output_type: Optional[str] = "np",
|
||||||
|
return_dict: bool = True,
|
||||||
|
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
callback_on_step_end: Optional[
|
||||||
|
Union[Callable[[int, int, Dict], None],
|
||||||
|
PipelineCallback, MultiPipelineCallbacks]
|
||||||
|
] = None,
|
||||||
|
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
noise_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||||
|
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||||
|
|
||||||
|
# unload vae and transformer
|
||||||
|
vae_device = self.vae.device
|
||||||
|
transformer_device = self.transformer.device
|
||||||
|
text_encoder_device = self.text_encoder.device
|
||||||
|
device = self.transformer.device
|
||||||
|
|
||||||
|
if self._aggressive_offload:
|
||||||
|
print("Unloading vae")
|
||||||
|
self.vae.to("cpu")
|
||||||
|
self.text_encoder.to(device)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
|
||||||
|
# 1. Check inputs. Raise error if not correct
|
||||||
|
self.check_inputs(
|
||||||
|
prompt,
|
||||||
|
negative_prompt,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
prompt_embeds,
|
||||||
|
negative_prompt_embeds,
|
||||||
|
callback_on_step_end_tensor_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._guidance_scale = guidance_scale
|
||||||
|
self._attention_kwargs = attention_kwargs
|
||||||
|
self._current_timestep = None
|
||||||
|
self._interrupt = False
|
||||||
|
|
||||||
|
# 2. Define call parameters
|
||||||
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
# 3. Encode input prompt
|
||||||
|
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||||
|
num_videos_per_prompt=num_videos_per_prompt,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
if self._aggressive_offload:
|
||||||
|
# unload text encoder
|
||||||
|
print("Unloading text encoder")
|
||||||
|
self.text_encoder.to("cpu")
|
||||||
|
self.transformer.to(device)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
transformer_dtype = self.transformer.dtype
|
||||||
|
prompt_embeds = prompt_embeds.to(device, transformer_dtype)
|
||||||
|
if negative_prompt_embeds is not None:
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.to(
|
||||||
|
device, transformer_dtype)
|
||||||
|
|
||||||
|
# 4. Prepare timesteps
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
|
timesteps = self.scheduler.timesteps
|
||||||
|
|
||||||
|
# 5. Prepare latent variables
|
||||||
|
num_channels_latents = self.transformer.config.in_channels
|
||||||
|
latents = self.prepare_latents(
|
||||||
|
batch_size * num_videos_per_prompt,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
num_frames,
|
||||||
|
torch.float32,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents,
|
||||||
|
)
|
||||||
|
|
||||||
|
mask = noise_mask
|
||||||
|
if mask is None:
|
||||||
|
mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
|
# 6. Denoising loop
|
||||||
|
num_warmup_steps = len(timesteps) - \
|
||||||
|
num_inference_steps * self.scheduler.order
|
||||||
|
self._num_timesteps = len(timesteps)
|
||||||
|
|
||||||
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||||
|
for i, t in enumerate(timesteps):
|
||||||
|
if self.interrupt:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._current_timestep = t
|
||||||
|
latent_model_input = latents.to(device, transformer_dtype)
|
||||||
|
if self.config.expand_timesteps:
|
||||||
|
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
|
||||||
|
temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
|
||||||
|
# batch_size, seq_len
|
||||||
|
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
|
||||||
|
else:
|
||||||
|
timestep = t.expand(latents.shape[0])
|
||||||
|
|
||||||
|
noise_pred = self.transformer(
|
||||||
|
hidden_states=latent_model_input,
|
||||||
|
timestep=timestep,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
attention_kwargs=attention_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
if self.do_classifier_free_guidance:
|
||||||
|
noise_uncond = self.transformer(
|
||||||
|
hidden_states=latent_model_input,
|
||||||
|
timestep=timestep,
|
||||||
|
encoder_hidden_states=negative_prompt_embeds,
|
||||||
|
attention_kwargs=attention_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
noise_pred = noise_uncond + guidance_scale * \
|
||||||
|
(noise_pred - noise_uncond)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
latents = self.scheduler.step(
|
||||||
|
noise_pred, t, latents, return_dict=False)[0]
|
||||||
|
|
||||||
|
# apply i2v mask
|
||||||
|
latents = (latent_model_input * (1 - mask)) + (
|
||||||
|
latents * mask
|
||||||
|
)
|
||||||
|
|
||||||
|
if callback_on_step_end is not None:
|
||||||
|
callback_kwargs = {}
|
||||||
|
for k in callback_on_step_end_tensor_inputs:
|
||||||
|
callback_kwargs[k] = locals()[k]
|
||||||
|
callback_outputs = callback_on_step_end(
|
||||||
|
self, i, t, callback_kwargs)
|
||||||
|
|
||||||
|
latents = callback_outputs.pop("latents", latents)
|
||||||
|
prompt_embeds = callback_outputs.pop(
|
||||||
|
"prompt_embeds", prompt_embeds)
|
||||||
|
negative_prompt_embeds = callback_outputs.pop(
|
||||||
|
"negative_prompt_embeds", negative_prompt_embeds)
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||||
|
progress_bar.update()
|
||||||
|
|
||||||
|
if XLA_AVAILABLE:
|
||||||
|
xm.mark_step()
|
||||||
|
|
||||||
|
self._current_timestep = None
|
||||||
|
|
||||||
|
if self._aggressive_offload:
|
||||||
|
# unload transformer
|
||||||
|
print("Unloading transformer")
|
||||||
|
self.transformer.to("cpu")
|
||||||
|
if self.transformer_2 is not None:
|
||||||
|
self.transformer_2.to("cpu")
|
||||||
|
# load vae
|
||||||
|
print("Loading Vae")
|
||||||
|
self.vae.to(vae_device)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
if not output_type == "latent":
|
||||||
|
latents = latents.to(self.vae.dtype)
|
||||||
|
latents_mean = (
|
||||||
|
torch.tensor(self.vae.config.latents_mean)
|
||||||
|
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||||
|
.to(latents.device, latents.dtype)
|
||||||
|
)
|
||||||
|
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||||
|
latents.device, latents.dtype
|
||||||
|
)
|
||||||
|
latents = latents / latents_std + latents_mean
|
||||||
|
video = self.vae.decode(latents, return_dict=False)[0]
|
||||||
|
video = self.video_processor.postprocess_video(
|
||||||
|
video, output_type=output_type)
|
||||||
|
else:
|
||||||
|
video = latents
|
||||||
|
|
||||||
|
# Offload all models
|
||||||
|
self.maybe_free_model_hooks()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (video,)
|
||||||
|
|
||||||
|
return WanPipelineOutput(frames=video)
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
torchao==0.10.0
|
torchao==0.10.0
|
||||||
safetensors
|
safetensors
|
||||||
git+https://github.com/jaretburkett/easy_dwpose.git
|
git+https://github.com/jaretburkett/easy_dwpose.git
|
||||||
git+https://github.com/huggingface/diffusers@00f95b9755718aabb65456e791b8408526ae6e76
|
git+https://github.com/huggingface/diffusers@56d438727036b0918b30bbe3110c5fe1634ed19d
|
||||||
transformers==4.52.4
|
transformers==4.52.4
|
||||||
lycoris-lora==1.8.3
|
lycoris-lora==1.8.3
|
||||||
flatten_json
|
flatten_json
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import torch.nn.functional as F
|
|||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
|
|
||||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
|
from diffusers.loaders import FromOriginalModelMixin
|
||||||
from diffusers.utils import logging
|
from diffusers.utils import logging
|
||||||
from diffusers.utils.accelerate_utils import apply_forward_hook
|
from diffusers.utils.accelerate_utils import apply_forward_hook
|
||||||
from diffusers.models.activations import get_activation
|
from diffusers.models.activations import get_activation
|
||||||
@@ -34,6 +35,104 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|||||||
CACHE_T = 2
|
CACHE_T = 2
|
||||||
|
|
||||||
|
|
||||||
|
class AvgDown3D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
factor_t,
|
||||||
|
factor_s=1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.factor_t = factor_t
|
||||||
|
self.factor_s = factor_s
|
||||||
|
self.factor = self.factor_t * self.factor_s * self.factor_s
|
||||||
|
|
||||||
|
assert in_channels * self.factor % out_channels == 0
|
||||||
|
self.group_size = in_channels * self.factor // out_channels
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
|
||||||
|
pad = (0, 0, 0, 0, pad_t, 0)
|
||||||
|
x = F.pad(x, pad)
|
||||||
|
B, C, T, H, W = x.shape
|
||||||
|
x = x.view(
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
T // self.factor_t,
|
||||||
|
self.factor_t,
|
||||||
|
H // self.factor_s,
|
||||||
|
self.factor_s,
|
||||||
|
W // self.factor_s,
|
||||||
|
self.factor_s,
|
||||||
|
)
|
||||||
|
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
|
||||||
|
x = x.view(
|
||||||
|
B,
|
||||||
|
C * self.factor,
|
||||||
|
T // self.factor_t,
|
||||||
|
H // self.factor_s,
|
||||||
|
W // self.factor_s,
|
||||||
|
)
|
||||||
|
x = x.view(
|
||||||
|
B,
|
||||||
|
self.out_channels,
|
||||||
|
self.group_size,
|
||||||
|
T // self.factor_t,
|
||||||
|
H // self.factor_s,
|
||||||
|
W // self.factor_s,
|
||||||
|
)
|
||||||
|
x = x.mean(dim=2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DupUp3D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
factor_t,
|
||||||
|
factor_s=1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
self.factor_t = factor_t
|
||||||
|
self.factor_s = factor_s
|
||||||
|
self.factor = self.factor_t * self.factor_s * self.factor_s
|
||||||
|
|
||||||
|
assert out_channels * self.factor % in_channels == 0
|
||||||
|
self.repeats = out_channels * self.factor // in_channels
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
|
||||||
|
x = x.repeat_interleave(self.repeats, dim=1)
|
||||||
|
x = x.view(
|
||||||
|
x.size(0),
|
||||||
|
self.out_channels,
|
||||||
|
self.factor_t,
|
||||||
|
self.factor_s,
|
||||||
|
self.factor_s,
|
||||||
|
x.size(2),
|
||||||
|
x.size(3),
|
||||||
|
x.size(4),
|
||||||
|
)
|
||||||
|
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
|
||||||
|
x = x.view(
|
||||||
|
x.size(0),
|
||||||
|
self.out_channels,
|
||||||
|
x.size(2) * self.factor_t,
|
||||||
|
x.size(4) * self.factor_s,
|
||||||
|
x.size(6) * self.factor_s,
|
||||||
|
)
|
||||||
|
if first_chunk:
|
||||||
|
x = x[:, :, self.factor_t - 1:, :, :]
|
||||||
|
return x
|
||||||
|
|
||||||
class WanCausalConv3d(nn.Conv3d):
|
class WanCausalConv3d(nn.Conv3d):
|
||||||
r"""
|
r"""
|
||||||
A custom 3D causal convolution layer with feature caching support.
|
A custom 3D causal convolution layer with feature caching support.
|
||||||
@@ -134,19 +233,23 @@ class WanResample(nn.Module):
|
|||||||
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
|
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dim: int, mode: str) -> None:
|
def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
|
# default to dim //2
|
||||||
|
if upsample_out_dim is None:
|
||||||
|
upsample_out_dim = dim // 2
|
||||||
|
|
||||||
# layers
|
# layers
|
||||||
if mode == "upsample2d":
|
if mode == "upsample2d":
|
||||||
self.resample = nn.Sequential(
|
self.resample = nn.Sequential(
|
||||||
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
|
||||||
)
|
)
|
||||||
elif mode == "upsample3d":
|
elif mode == "upsample3d":
|
||||||
self.resample = nn.Sequential(
|
self.resample = nn.Sequential(
|
||||||
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
|
||||||
)
|
)
|
||||||
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||||
|
|
||||||
@@ -363,6 +466,48 @@ class WanMidBlock(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class WanResidualDownBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_dim,
|
||||||
|
out_dim,
|
||||||
|
dropout,
|
||||||
|
num_res_blocks,
|
||||||
|
temperal_downsample=False,
|
||||||
|
down_flag=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Shortcut path with downsample
|
||||||
|
self.avg_shortcut = AvgDown3D(
|
||||||
|
in_dim,
|
||||||
|
out_dim,
|
||||||
|
factor_t=2 if temperal_downsample else 1,
|
||||||
|
factor_s=2 if down_flag else 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Main path with residual blocks and downsample
|
||||||
|
resnets = []
|
||||||
|
for _ in range(num_res_blocks):
|
||||||
|
resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
|
||||||
|
in_dim = out_dim
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
# Add the final downsample block
|
||||||
|
if down_flag:
|
||||||
|
mode = "downsample3d" if temperal_downsample else "downsample2d"
|
||||||
|
self.downsampler = WanResample(out_dim, mode=mode)
|
||||||
|
else:
|
||||||
|
self.downsampler = None
|
||||||
|
|
||||||
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
|
x_copy = x.clone()
|
||||||
|
for resnet in self.resnets:
|
||||||
|
x = resnet(x, feat_cache, feat_idx)
|
||||||
|
if self.downsampler is not None:
|
||||||
|
x = self.downsampler(x, feat_cache, feat_idx)
|
||||||
|
|
||||||
|
return x + self.avg_shortcut(x_copy)
|
||||||
|
|
||||||
class WanEncoder3d(nn.Module):
|
class WanEncoder3d(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
A 3D encoder module.
|
A 3D encoder module.
|
||||||
@@ -380,6 +525,7 @@ class WanEncoder3d(nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
in_channels: int = 3,
|
||||||
dim=128,
|
dim=128,
|
||||||
z_dim=4,
|
z_dim=4,
|
||||||
dim_mult=[1, 2, 4, 4],
|
dim_mult=[1, 2, 4, 4],
|
||||||
@@ -388,6 +534,7 @@ class WanEncoder3d(nn.Module):
|
|||||||
temperal_downsample=[True, True, False],
|
temperal_downsample=[True, True, False],
|
||||||
dropout=0.0,
|
dropout=0.0,
|
||||||
non_linearity: str = "silu",
|
non_linearity: str = "silu",
|
||||||
|
is_residual: bool = False, # wan 2.2 vae use a residual downblock
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@@ -403,23 +550,35 @@ class WanEncoder3d(nn.Module):
|
|||||||
scale = 1.0
|
scale = 1.0
|
||||||
|
|
||||||
# init block
|
# init block
|
||||||
self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
|
self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
|
||||||
|
|
||||||
# downsample blocks
|
# downsample blocks
|
||||||
self.down_blocks = nn.ModuleList([])
|
self.down_blocks = nn.ModuleList([])
|
||||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||||
# residual (+attention) blocks
|
# residual (+attention) blocks
|
||||||
for _ in range(num_res_blocks):
|
if is_residual:
|
||||||
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
|
self.down_blocks.append(
|
||||||
if scale in attn_scales:
|
WanResidualDownBlock(
|
||||||
self.down_blocks.append(WanAttentionBlock(out_dim))
|
in_dim,
|
||||||
in_dim = out_dim
|
out_dim,
|
||||||
|
dropout,
|
||||||
|
num_res_blocks,
|
||||||
|
temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
|
||||||
|
down_flag=i != len(dim_mult) - 1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for _ in range(num_res_blocks):
|
||||||
|
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
|
||||||
|
if scale in attn_scales:
|
||||||
|
self.down_blocks.append(WanAttentionBlock(out_dim))
|
||||||
|
in_dim = out_dim
|
||||||
|
|
||||||
# downsample block
|
# downsample block
|
||||||
if i != len(dim_mult) - 1:
|
if i != len(dim_mult) - 1:
|
||||||
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
||||||
self.down_blocks.append(WanResample(out_dim, mode=mode))
|
self.down_blocks.append(WanResample(out_dim, mode=mode))
|
||||||
scale /= 2.0
|
scale /= 2.0
|
||||||
|
|
||||||
# middle blocks
|
# middle blocks
|
||||||
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
|
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
|
||||||
@@ -469,6 +628,92 @@ class WanEncoder3d(nn.Module):
|
|||||||
x = self.conv_out(x)
|
x = self.conv_out(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
class WanResidualUpBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
A block that handles upsampling for the WanVAE decoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_dim (int): Input dimension
|
||||||
|
out_dim (int): Output dimension
|
||||||
|
num_res_blocks (int): Number of residual blocks
|
||||||
|
dropout (float): Dropout rate
|
||||||
|
temperal_upsample (bool): Whether to upsample on temporal dimension
|
||||||
|
up_flag (bool): Whether to upsample or not
|
||||||
|
non_linearity (str): Type of non-linearity to use
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_dim: int,
|
||||||
|
out_dim: int,
|
||||||
|
num_res_blocks: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
temperal_upsample: bool = False,
|
||||||
|
up_flag: bool = False,
|
||||||
|
non_linearity: str = "silu",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_dim = in_dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
|
||||||
|
if up_flag:
|
||||||
|
self.avg_shortcut = DupUp3D(
|
||||||
|
in_dim,
|
||||||
|
out_dim,
|
||||||
|
factor_t=2 if temperal_upsample else 1,
|
||||||
|
factor_s=2,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.avg_shortcut = None
|
||||||
|
|
||||||
|
# create residual blocks
|
||||||
|
resnets = []
|
||||||
|
current_dim = in_dim
|
||||||
|
for _ in range(num_res_blocks + 1):
|
||||||
|
resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
|
||||||
|
current_dim = out_dim
|
||||||
|
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
# Add upsampling layer if needed
|
||||||
|
if up_flag:
|
||||||
|
upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
|
||||||
|
self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
|
||||||
|
else:
|
||||||
|
self.upsampler = None
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||||
|
"""
|
||||||
|
Forward pass through the upsampling block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): Input tensor
|
||||||
|
feat_cache (list, optional): Feature cache for causal convolutions
|
||||||
|
feat_idx (list, optional): Feature index for cache management
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Output tensor
|
||||||
|
"""
|
||||||
|
x_copy = x.clone()
|
||||||
|
|
||||||
|
for resnet in self.resnets:
|
||||||
|
if feat_cache is not None:
|
||||||
|
x = resnet(x, feat_cache, feat_idx)
|
||||||
|
else:
|
||||||
|
x = resnet(x)
|
||||||
|
|
||||||
|
if self.upsampler is not None:
|
||||||
|
if feat_cache is not None:
|
||||||
|
x = self.upsampler(x, feat_cache, feat_idx)
|
||||||
|
else:
|
||||||
|
x = self.upsampler(x)
|
||||||
|
|
||||||
|
if self.avg_shortcut is not None:
|
||||||
|
x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
class WanUpBlock(nn.Module):
|
class WanUpBlock(nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -513,7 +758,7 @@ class WanUpBlock(nn.Module):
|
|||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
|
||||||
"""
|
"""
|
||||||
Forward pass through the upsampling block.
|
Forward pass through the upsampling block.
|
||||||
|
|
||||||
@@ -564,6 +809,8 @@ class WanDecoder3d(nn.Module):
|
|||||||
temperal_upsample=[False, True, True],
|
temperal_upsample=[False, True, True],
|
||||||
dropout=0.0,
|
dropout=0.0,
|
||||||
non_linearity: str = "silu",
|
non_linearity: str = "silu",
|
||||||
|
out_channels: int = 3,
|
||||||
|
is_residual: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@@ -577,7 +824,6 @@ class WanDecoder3d(nn.Module):
|
|||||||
|
|
||||||
# dimensions
|
# dimensions
|
||||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||||
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
|
||||||
|
|
||||||
# init block
|
# init block
|
||||||
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
|
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||||
@@ -589,36 +835,47 @@ class WanDecoder3d(nn.Module):
|
|||||||
self.up_blocks = nn.ModuleList([])
|
self.up_blocks = nn.ModuleList([])
|
||||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||||
# residual (+attention) blocks
|
# residual (+attention) blocks
|
||||||
if i > 0:
|
if i > 0 and not is_residual:
|
||||||
|
# wan vae 2.1
|
||||||
in_dim = in_dim // 2
|
in_dim = in_dim // 2
|
||||||
|
|
||||||
# Determine if we need upsampling
|
# determine if we need upsampling
|
||||||
|
up_flag = i != len(dim_mult) - 1
|
||||||
|
# determine upsampling mode, if not upsampling, set to None
|
||||||
upsample_mode = None
|
upsample_mode = None
|
||||||
if i != len(dim_mult) - 1:
|
if up_flag and temperal_upsample[i]:
|
||||||
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
|
upsample_mode = "upsample3d"
|
||||||
|
elif up_flag:
|
||||||
|
upsample_mode = "upsample2d"
|
||||||
# Create and add the upsampling block
|
# Create and add the upsampling block
|
||||||
up_block = WanUpBlock(
|
if is_residual:
|
||||||
in_dim=in_dim,
|
up_block = WanResidualUpBlock(
|
||||||
out_dim=out_dim,
|
in_dim=in_dim,
|
||||||
num_res_blocks=num_res_blocks,
|
out_dim=out_dim,
|
||||||
dropout=dropout,
|
num_res_blocks=num_res_blocks,
|
||||||
upsample_mode=upsample_mode,
|
dropout=dropout,
|
||||||
non_linearity=non_linearity,
|
temperal_upsample=temperal_upsample[i] if up_flag else False,
|
||||||
)
|
up_flag= up_flag,
|
||||||
|
non_linearity=non_linearity,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
up_block = WanUpBlock(
|
||||||
|
in_dim=in_dim,
|
||||||
|
out_dim=out_dim,
|
||||||
|
num_res_blocks=num_res_blocks,
|
||||||
|
dropout=dropout,
|
||||||
|
upsample_mode=upsample_mode,
|
||||||
|
non_linearity=non_linearity,
|
||||||
|
)
|
||||||
self.up_blocks.append(up_block)
|
self.up_blocks.append(up_block)
|
||||||
|
|
||||||
# Update scale for next iteration
|
|
||||||
if upsample_mode is not None:
|
|
||||||
scale *= 2.0
|
|
||||||
|
|
||||||
# output blocks
|
# output blocks
|
||||||
self.norm_out = WanRMS_norm(out_dim, images=False)
|
self.norm_out = WanRMS_norm(out_dim, images=False)
|
||||||
self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
|
self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||||
## conv1
|
## conv1
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
@@ -633,20 +890,11 @@ class WanDecoder3d(nn.Module):
|
|||||||
x = self.conv_in(x)
|
x = self.conv_in(x)
|
||||||
|
|
||||||
## middle
|
## middle
|
||||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
x = self.mid_block(x, feat_cache, feat_idx)
|
||||||
# middle
|
|
||||||
x = self._gradient_checkpointing_func(self.mid_block, x, feat_cache, feat_idx)
|
|
||||||
|
|
||||||
## upsamples
|
|
||||||
for up_block in self.up_blocks:
|
|
||||||
x = self._gradient_checkpointing_func(up_block, x, feat_cache, feat_idx)
|
|
||||||
|
|
||||||
else:
|
|
||||||
x = self.mid_block(x, feat_cache, feat_idx)
|
|
||||||
|
|
||||||
## upsamples
|
## upsamples
|
||||||
for up_block in self.up_blocks:
|
for up_block in self.up_blocks:
|
||||||
x = up_block(x, feat_cache, feat_idx)
|
x = up_block(x, feat_cache, feat_idx, first_chunk = first_chunk)
|
||||||
|
|
||||||
## head
|
## head
|
||||||
x = self.norm_out(x)
|
x = self.norm_out(x)
|
||||||
@@ -665,7 +913,46 @@ class WanDecoder3d(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
def patchify(x, patch_size):
|
||||||
|
# YiYi TODO: refactor this
|
||||||
|
from einops import rearrange
|
||||||
|
if patch_size == 1:
|
||||||
|
return x
|
||||||
|
if x.dim() == 4:
|
||||||
|
x = rearrange(
|
||||||
|
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
|
||||||
|
elif x.dim() == 5:
|
||||||
|
x = rearrange(
|
||||||
|
x,
|
||||||
|
"b c f (h q) (w r) -> b (c r q) f h w",
|
||||||
|
q=patch_size,
|
||||||
|
r=patch_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid input shape: {x.shape}")
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def unpatchify(x, patch_size):
|
||||||
|
# YiYi TODO: refactor this
|
||||||
|
from einops import rearrange
|
||||||
|
if patch_size == 1:
|
||||||
|
return x
|
||||||
|
|
||||||
|
if x.dim() == 4:
|
||||||
|
x = rearrange(
|
||||||
|
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
|
||||||
|
elif x.dim() == 5:
|
||||||
|
x = rearrange(
|
||||||
|
x,
|
||||||
|
"b (c r q) f h w -> b c f (h q) (w r)",
|
||||||
|
q=patch_size,
|
||||||
|
r=patch_size,
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||||
r"""
|
r"""
|
||||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
|
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
|
||||||
Introduced in [Wan 2.1].
|
Introduced in [Wan 2.1].
|
||||||
@@ -674,12 +961,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
|||||||
for all models (such as downloading or saving).
|
for all models (such as downloading or saving).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_supports_gradient_checkpointing = True
|
_supports_gradient_checkpointing = False
|
||||||
|
|
||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
base_dim: int = 96,
|
base_dim: int = 96,
|
||||||
|
decoder_base_dim: Optional[int] = None,
|
||||||
z_dim: int = 16,
|
z_dim: int = 16,
|
||||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||||
num_res_blocks: int = 2,
|
num_res_blocks: int = 2,
|
||||||
@@ -722,6 +1010,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
|||||||
2.8251,
|
2.8251,
|
||||||
1.9160,
|
1.9160,
|
||||||
],
|
],
|
||||||
|
is_residual: bool = False,
|
||||||
|
in_channels: int = 3,
|
||||||
|
out_channels: int = 3,
|
||||||
|
patch_size: Optional[int] = None,
|
||||||
|
scale_factor_temporal: Optional[int] = 4,
|
||||||
|
scale_factor_spatial: Optional[int] = 8,
|
||||||
|
clip_output: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -729,37 +1024,119 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
|||||||
self.temperal_downsample = temperal_downsample
|
self.temperal_downsample = temperal_downsample
|
||||||
self.temperal_upsample = temperal_downsample[::-1]
|
self.temperal_upsample = temperal_downsample[::-1]
|
||||||
|
|
||||||
|
if decoder_base_dim is None:
|
||||||
|
decoder_base_dim = base_dim
|
||||||
|
|
||||||
self.encoder = WanEncoder3d(
|
self.encoder = WanEncoder3d(
|
||||||
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
|
in_channels=in_channels, dim=base_dim, z_dim=z_dim * 2, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, dropout=dropout, is_residual=is_residual
|
||||||
)
|
)
|
||||||
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||||
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
|
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
|
||||||
|
|
||||||
self.decoder = WanDecoder3d(
|
self.decoder = WanDecoder3d(
|
||||||
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
|
dim=decoder_base_dim, z_dim=z_dim, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_upsample=self.temperal_upsample, dropout=dropout, out_channels=out_channels, is_residual=is_residual
|
||||||
)
|
)
|
||||||
|
|
||||||
def clear_cache(self):
|
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
|
||||||
def _count_conv3d(model):
|
|
||||||
count = 0
|
|
||||||
for m in model.modules():
|
|
||||||
if isinstance(m, WanCausalConv3d):
|
|
||||||
count += 1
|
|
||||||
return count
|
|
||||||
|
|
||||||
self._conv_num = _count_conv3d(self.decoder)
|
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
||||||
|
# to perform decoding of a single video latent at a time.
|
||||||
|
self.use_slicing = False
|
||||||
|
|
||||||
|
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
|
||||||
|
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
|
||||||
|
# intermediate tiles together, the memory requirement can be lowered.
|
||||||
|
self.use_tiling = False
|
||||||
|
|
||||||
|
# The minimal tile height and width for spatial tiling to be used
|
||||||
|
self.tile_sample_min_height = 256
|
||||||
|
self.tile_sample_min_width = 256
|
||||||
|
|
||||||
|
# The minimal distance between two spatial tiles
|
||||||
|
self.tile_sample_stride_height = 192
|
||||||
|
self.tile_sample_stride_width = 192
|
||||||
|
|
||||||
|
# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
|
||||||
|
self._cached_conv_counts = {
|
||||||
|
"decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
|
||||||
|
if self.decoder is not None
|
||||||
|
else 0,
|
||||||
|
"encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
|
||||||
|
if self.encoder is not None
|
||||||
|
else 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def enable_tiling(
|
||||||
|
self,
|
||||||
|
tile_sample_min_height: Optional[int] = None,
|
||||||
|
tile_sample_min_width: Optional[int] = None,
|
||||||
|
tile_sample_stride_height: Optional[float] = None,
|
||||||
|
tile_sample_stride_width: Optional[float] = None,
|
||||||
|
) -> None:
|
||||||
|
r"""
|
||||||
|
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||||
|
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||||
|
processing larger images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tile_sample_min_height (`int`, *optional*):
|
||||||
|
The minimum height required for a sample to be separated into tiles across the height dimension.
|
||||||
|
tile_sample_min_width (`int`, *optional*):
|
||||||
|
The minimum width required for a sample to be separated into tiles across the width dimension.
|
||||||
|
tile_sample_stride_height (`int`, *optional*):
|
||||||
|
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
|
||||||
|
no tiling artifacts produced across the height dimension.
|
||||||
|
tile_sample_stride_width (`int`, *optional*):
|
||||||
|
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
|
||||||
|
artifacts produced across the width dimension.
|
||||||
|
"""
|
||||||
|
self.use_tiling = True
|
||||||
|
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
||||||
|
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
||||||
|
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
|
||||||
|
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
||||||
|
|
||||||
|
def disable_tiling(self) -> None:
|
||||||
|
r"""
|
||||||
|
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||||
|
decoding in one step.
|
||||||
|
"""
|
||||||
|
self.use_tiling = False
|
||||||
|
|
||||||
|
def enable_slicing(self) -> None:
|
||||||
|
r"""
|
||||||
|
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||||
|
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||||
|
"""
|
||||||
|
self.use_slicing = True
|
||||||
|
|
||||||
|
def disable_slicing(self) -> None:
|
||||||
|
r"""
|
||||||
|
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||||
|
decoding in one step.
|
||||||
|
"""
|
||||||
|
self.use_slicing = False
|
||||||
|
|
||||||
|
def clear_cache(self):
|
||||||
|
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
|
||||||
|
self._conv_num = self._cached_conv_counts["decoder"]
|
||||||
self._conv_idx = [0]
|
self._conv_idx = [0]
|
||||||
self._feat_map = [None] * self._conv_num
|
self._feat_map = [None] * self._conv_num
|
||||||
# cache encode
|
# cache encode
|
||||||
self._enc_conv_num = _count_conv3d(self.encoder)
|
self._enc_conv_num = self._cached_conv_counts["encoder"]
|
||||||
self._enc_conv_idx = [0]
|
self._enc_conv_idx = [0]
|
||||||
self._enc_feat_map = [None] * self._enc_conv_num
|
self._enc_feat_map = [None] * self._enc_conv_num
|
||||||
|
|
||||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
def _encode(self, x: torch.Tensor):
|
||||||
|
_, _, num_frame, height, width = x.shape
|
||||||
|
|
||||||
|
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||||
|
return self.tiled_encode(x)
|
||||||
|
|
||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
## cache
|
if self.config.patch_size is not None:
|
||||||
t = x.shape[2]
|
x = patchify(x, patch_size=self.config.patch_size)
|
||||||
iter_ = 1 + (t - 1) // 4
|
iter_ = 1 + (num_frame - 1) // 4
|
||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
self._enc_conv_idx = [0]
|
self._enc_conv_idx = [0]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@@ -773,8 +1150,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
|||||||
out = torch.cat([out, out_], 2)
|
out = torch.cat([out, out_], 2)
|
||||||
|
|
||||||
enc = self.quant_conv(out)
|
enc = self.quant_conv(out)
|
||||||
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
|
|
||||||
enc = torch.cat([mu, logvar], dim=1)
|
|
||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
return enc
|
return enc
|
||||||
|
|
||||||
@@ -794,27 +1169,39 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
|||||||
The latent representations of the encoded videos. If `return_dict` is True, a
|
The latent representations of the encoded videos. If `return_dict` is True, a
|
||||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||||
"""
|
"""
|
||||||
h = self._encode(x)
|
if self.use_slicing and x.shape[0] > 1:
|
||||||
|
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||||
|
h = torch.cat(encoded_slices)
|
||||||
|
else:
|
||||||
|
h = self._encode(x)
|
||||||
posterior = DiagonalGaussianDistribution(h)
|
posterior = DiagonalGaussianDistribution(h)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (posterior,)
|
return (posterior,)
|
||||||
return AutoencoderKLOutput(latent_dist=posterior)
|
return AutoencoderKLOutput(latent_dist=posterior)
|
||||||
|
|
||||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
def _decode(self, z: torch.Tensor, return_dict: bool = True):
|
||||||
self.clear_cache()
|
_, _, num_frame, height, width = z.shape
|
||||||
|
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||||
|
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||||
|
|
||||||
iter_ = z.shape[2]
|
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
|
||||||
|
return self.tiled_decode(z, return_dict=return_dict)
|
||||||
|
|
||||||
|
self.clear_cache()
|
||||||
x = self.post_quant_conv(z)
|
x = self.post_quant_conv(z)
|
||||||
for i in range(iter_):
|
for i in range(num_frame):
|
||||||
|
|
||||||
self._conv_idx = [0]
|
self._conv_idx = [0]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True)
|
||||||
else:
|
else:
|
||||||
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||||
out = torch.cat([out, out_], 2)
|
out = torch.cat([out, out_], 2)
|
||||||
|
|
||||||
out = torch.clamp(out, min=-1.0, max=1.0)
|
if self.config.clip_output:
|
||||||
|
out = torch.clamp(out, min=-1.0, max=1.0)
|
||||||
|
if self.config.patch_size is not None:
|
||||||
|
out = unpatchify(out, patch_size=self.config.patch_size)
|
||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (out,)
|
return (out,)
|
||||||
@@ -836,12 +1223,161 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
|||||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||||
returned.
|
returned.
|
||||||
"""
|
"""
|
||||||
decoded = self._decode(z).sample
|
if self.use_slicing and z.shape[0] > 1:
|
||||||
|
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||||
|
decoded = torch.cat(decoded_slices)
|
||||||
|
else:
|
||||||
|
decoded = self._decode(z).sample
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (decoded,)
|
return (decoded,)
|
||||||
|
|
||||||
return DecoderOutput(sample=decoded)
|
return DecoderOutput(sample=decoded)
|
||||||
|
|
||||||
|
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||||
|
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
||||||
|
for y in range(blend_extent):
|
||||||
|
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
||||||
|
y / blend_extent
|
||||||
|
)
|
||||||
|
return b
|
||||||
|
|
||||||
|
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||||
|
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
||||||
|
for x in range(blend_extent):
|
||||||
|
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
||||||
|
x / blend_extent
|
||||||
|
)
|
||||||
|
return b
|
||||||
|
|
||||||
|
def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
|
||||||
|
r"""Encode a batch of images using a tiled encoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (`torch.Tensor`): Input batch of videos.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor`:
|
||||||
|
The latent representation of the encoded videos.
|
||||||
|
"""
|
||||||
|
_, _, num_frames, height, width = x.shape
|
||||||
|
latent_height = height // self.spatial_compression_ratio
|
||||||
|
latent_width = width // self.spatial_compression_ratio
|
||||||
|
|
||||||
|
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||||
|
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||||
|
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
|
||||||
|
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||||
|
|
||||||
|
blend_height = tile_latent_min_height - tile_latent_stride_height
|
||||||
|
blend_width = tile_latent_min_width - tile_latent_stride_width
|
||||||
|
|
||||||
|
# Split x into overlapping tiles and encode them separately.
|
||||||
|
# The tiles have an overlap to avoid seams between tiles.
|
||||||
|
rows = []
|
||||||
|
for i in range(0, height, self.tile_sample_stride_height):
|
||||||
|
row = []
|
||||||
|
for j in range(0, width, self.tile_sample_stride_width):
|
||||||
|
self.clear_cache()
|
||||||
|
time = []
|
||||||
|
frame_range = 1 + (num_frames - 1) // 4
|
||||||
|
for k in range(frame_range):
|
||||||
|
self._enc_conv_idx = [0]
|
||||||
|
if k == 0:
|
||||||
|
tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
|
||||||
|
else:
|
||||||
|
tile = x[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
1 + 4 * (k - 1) : 1 + 4 * k,
|
||||||
|
i : i + self.tile_sample_min_height,
|
||||||
|
j : j + self.tile_sample_min_width,
|
||||||
|
]
|
||||||
|
tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
|
||||||
|
tile = self.quant_conv(tile)
|
||||||
|
time.append(tile)
|
||||||
|
row.append(torch.cat(time, dim=2))
|
||||||
|
rows.append(row)
|
||||||
|
self.clear_cache()
|
||||||
|
|
||||||
|
result_rows = []
|
||||||
|
for i, row in enumerate(rows):
|
||||||
|
result_row = []
|
||||||
|
for j, tile in enumerate(row):
|
||||||
|
# blend the above tile and the left tile
|
||||||
|
# to the current tile and add the current tile to the result row
|
||||||
|
if i > 0:
|
||||||
|
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||||
|
if j > 0:
|
||||||
|
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||||
|
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
|
||||||
|
result_rows.append(torch.cat(result_row, dim=-1))
|
||||||
|
|
||||||
|
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||||
|
return enc
|
||||||
|
|
||||||
|
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||||
|
r"""
|
||||||
|
Decode a batch of images using a tiled decoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
z (`torch.Tensor`): Input batch of latent vectors.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||||
|
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||||
|
returned.
|
||||||
|
"""
|
||||||
|
_, _, num_frames, height, width = z.shape
|
||||||
|
sample_height = height * self.spatial_compression_ratio
|
||||||
|
sample_width = width * self.spatial_compression_ratio
|
||||||
|
|
||||||
|
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||||
|
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||||
|
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
|
||||||
|
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||||
|
|
||||||
|
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
|
||||||
|
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
|
||||||
|
|
||||||
|
# Split z into overlapping tiles and decode them separately.
|
||||||
|
# The tiles have an overlap to avoid seams between tiles.
|
||||||
|
rows = []
|
||||||
|
for i in range(0, height, tile_latent_stride_height):
|
||||||
|
row = []
|
||||||
|
for j in range(0, width, tile_latent_stride_width):
|
||||||
|
self.clear_cache()
|
||||||
|
time = []
|
||||||
|
for k in range(num_frames):
|
||||||
|
self._conv_idx = [0]
|
||||||
|
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
|
||||||
|
tile = self.post_quant_conv(tile)
|
||||||
|
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||||
|
time.append(decoded)
|
||||||
|
row.append(torch.cat(time, dim=2))
|
||||||
|
rows.append(row)
|
||||||
|
self.clear_cache()
|
||||||
|
|
||||||
|
result_rows = []
|
||||||
|
for i, row in enumerate(rows):
|
||||||
|
result_row = []
|
||||||
|
for j, tile in enumerate(row):
|
||||||
|
# blend the above tile and the left tile
|
||||||
|
# to the current tile and add the current tile to the result row
|
||||||
|
if i > 0:
|
||||||
|
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||||
|
if j > 0:
|
||||||
|
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||||
|
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
|
||||||
|
result_rows.append(torch.cat(result_row, dim=-1))
|
||||||
|
|
||||||
|
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (dec,)
|
||||||
|
return DecoderOutput(sample=dec)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
sample: torch.Tensor,
|
sample: torch.Tensor,
|
||||||
@@ -862,4 +1398,4 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
|||||||
else:
|
else:
|
||||||
z = posterior.mode()
|
z = posterior.mode()
|
||||||
dec = self.decode(z, return_dict=return_dict)
|
dec = self.decode(z, return_dict=return_dict)
|
||||||
return dec
|
return dec
|
||||||
@@ -89,12 +89,18 @@ class AggressiveWanUnloadPipeline(WanPipeline):
|
|||||||
transformer: WanTransformer3DModel,
|
transformer: WanTransformer3DModel,
|
||||||
vae: AutoencoderKLWan,
|
vae: AutoencoderKLWan,
|
||||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||||
|
transformer_2: Optional[WanTransformer3DModel] = None,
|
||||||
|
boundary_ratio: Optional[float] = None,
|
||||||
|
expand_timesteps: bool = False, # Wan2.2 ti2v
|
||||||
device: torch.device = torch.device("cuda"),
|
device: torch.device = torch.device("cuda"),
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
transformer=transformer,
|
transformer=transformer,
|
||||||
|
transformer_2=transformer_2,
|
||||||
|
boundary_ratio=boundary_ratio,
|
||||||
|
expand_timesteps=expand_timesteps,
|
||||||
vae=vae,
|
vae=vae,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
)
|
)
|
||||||
@@ -300,6 +306,8 @@ class AggressiveWanUnloadPipeline(WanPipeline):
|
|||||||
|
|
||||||
class Wan21(BaseModel):
|
class Wan21(BaseModel):
|
||||||
arch = 'wan21'
|
arch = 'wan21'
|
||||||
|
_wan_generation_scheduler_config = scheduler_configUniPC
|
||||||
|
_wan_expand_timesteps = False
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device,
|
device,
|
||||||
@@ -331,7 +339,7 @@ class Wan21(BaseModel):
|
|||||||
dtype = self.torch_dtype
|
dtype = self.torch_dtype
|
||||||
model_path = self.model_config.name_or_path
|
model_path = self.model_config.name_or_path
|
||||||
|
|
||||||
self.print_and_status_update("Loading Wan2.1 model")
|
self.print_and_status_update("Loading Wan model")
|
||||||
subfolder = 'transformer'
|
subfolder = 'transformer'
|
||||||
transformer_path = model_path
|
transformer_path = model_path
|
||||||
if os.path.exists(transformer_path):
|
if os.path.exists(transformer_path):
|
||||||
@@ -380,7 +388,6 @@ class Wan21(BaseModel):
|
|||||||
# patch the state dict method
|
# patch the state dict method
|
||||||
patch_dequantization_on_save(transformer)
|
patch_dequantization_on_save(transformer)
|
||||||
quantization_type = get_qtype(self.model_config.qtype)
|
quantization_type = get_qtype(self.model_config.qtype)
|
||||||
self.print_and_status_update("Quantizing transformer")
|
|
||||||
if self.model_config.low_vram:
|
if self.model_config.low_vram:
|
||||||
print("Quantizing blocks")
|
print("Quantizing blocks")
|
||||||
orig_exclude = copy.deepcopy(quantization_args['exclude'])
|
orig_exclude = copy.deepcopy(quantization_args['exclude'])
|
||||||
@@ -474,22 +481,26 @@ class Wan21(BaseModel):
|
|||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
def get_generation_pipeline(self):
|
def get_generation_pipeline(self):
|
||||||
scheduler = UniPCMultistepScheduler(**scheduler_configUniPC)
|
scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
|
||||||
if self.model_config.low_vram:
|
if self.model_config.low_vram:
|
||||||
pipeline = AggressiveWanUnloadPipeline(
|
pipeline = AggressiveWanUnloadPipeline(
|
||||||
vae=self.vae,
|
vae=self.vae,
|
||||||
transformer=self.model,
|
transformer=self.model,
|
||||||
|
transformer_2=self.model,
|
||||||
text_encoder=self.text_encoder,
|
text_encoder=self.text_encoder,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
|
expand_timesteps=self._wan_expand_timesteps,
|
||||||
device=self.device_torch
|
device=self.device_torch
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
pipeline = WanPipeline(
|
pipeline = WanPipeline(
|
||||||
vae=self.vae,
|
vae=self.vae,
|
||||||
transformer=self.unet,
|
transformer=self.unet,
|
||||||
|
transformer_2=self.unet,
|
||||||
text_encoder=self.text_encoder,
|
text_encoder=self.text_encoder,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
|
expand_timesteps=self._wan_expand_timesteps,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -48,11 +48,13 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
|
|||||||
self,
|
self,
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
text_encoder: UMT5EncoderModel,
|
text_encoder: UMT5EncoderModel,
|
||||||
image_encoder: CLIPVisionModel,
|
|
||||||
image_processor: CLIPImageProcessor,
|
|
||||||
transformer: WanTransformer3DModel,
|
transformer: WanTransformer3DModel,
|
||||||
vae: AutoencoderKLWan,
|
vae: AutoencoderKLWan,
|
||||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||||
|
image_processor: CLIPImageProcessor = None,
|
||||||
|
image_encoder: CLIPVisionModel = None,
|
||||||
|
transformer_2: WanTransformer3DModel = None,
|
||||||
|
boundary_ratio: Optional[float] = None,
|
||||||
device: torch.device = torch.device("cuda"),
|
device: torch.device = torch.device("cuda"),
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -63,6 +65,8 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
|
|||||||
transformer=transformer,
|
transformer=transformer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
image_processor=image_processor,
|
image_processor=image_processor,
|
||||||
|
transformer_2=transformer_2,
|
||||||
|
boundary_ratio=boundary_ratio,
|
||||||
)
|
)
|
||||||
self._exec_device = device
|
self._exec_device = device
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ def add_first_frame_conditioning(
|
|||||||
first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1)
|
first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1)
|
||||||
|
|
||||||
# resize first frame to match the latent model input
|
# resize first frame to match the latent model input
|
||||||
vae_scale_factor = 8
|
vae_scale_factor = vae.config.scale_factor_spatial
|
||||||
first_frame = F.interpolate(
|
first_frame = F.interpolate(
|
||||||
first_frame,
|
first_frame,
|
||||||
size=(latent_model_input.shape[3] * vae_scale_factor, latent_model_input.shape[4] * vae_scale_factor),
|
size=(latent_model_input.shape[3] * vae_scale_factor, latent_model_input.shape[4] * vae_scale_factor),
|
||||||
@@ -111,3 +111,55 @@ def add_first_frame_conditioning(
|
|||||||
[latent_model_input, first_frame_condition], dim=1)
|
[latent_model_input, first_frame_condition], dim=1)
|
||||||
|
|
||||||
return conditioned_latent
|
return conditioned_latent
|
||||||
|
|
||||||
|
|
||||||
|
def add_first_frame_conditioning_v22(
|
||||||
|
latent_model_input,
|
||||||
|
first_frame,
|
||||||
|
vae
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Overwrites first few time steps in latent_model_input with VAE-encoded first_frame,
|
||||||
|
and returns the modified latent + binary mask (0=conditioned, 1=noise).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
latent_model_input: torch.Tensor of shape (bs, 48, T, H, W)
|
||||||
|
first_frame: torch.Tensor of shape (bs, 3, H*scale, W*scale)
|
||||||
|
vae: VAE model with .encode() and .config.latents_mean/std
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
latent: (bs, 48, T, H, W) - modified input latent
|
||||||
|
mask: (bs, 1, T, H, W) - binary mask
|
||||||
|
"""
|
||||||
|
device = latent_model_input.device
|
||||||
|
dtype = latent_model_input.dtype
|
||||||
|
bs, _, T, H, W = latent_model_input.shape
|
||||||
|
scale = vae.config.scale_factor_spatial
|
||||||
|
target_h = H * scale
|
||||||
|
target_w = W * scale
|
||||||
|
|
||||||
|
# Ensure shape
|
||||||
|
if first_frame.ndim == 3:
|
||||||
|
first_frame = first_frame.unsqueeze(0)
|
||||||
|
if first_frame.shape[0] != bs:
|
||||||
|
first_frame = first_frame.expand(bs, -1, -1, -1)
|
||||||
|
|
||||||
|
# Resize and encode
|
||||||
|
first_frame_up = F.interpolate(first_frame, size=(target_h, target_w), mode="bilinear", align_corners=False)
|
||||||
|
first_frame_up = first_frame_up.unsqueeze(2) # (bs, 3, 1, H, W)
|
||||||
|
encoded = vae.encode(first_frame_up).latent_dist.sample().to(dtype).to(device)
|
||||||
|
|
||||||
|
# Normalize
|
||||||
|
mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
|
||||||
|
std = 1.0 / torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
|
||||||
|
encoded = (encoded - mean) * std
|
||||||
|
|
||||||
|
# Replace in latent
|
||||||
|
latent = latent_model_input.clone()
|
||||||
|
latent[:, :, :encoded.shape[2]] = encoded # typically first frame: [:, :, 0]
|
||||||
|
|
||||||
|
# Mask: 0 where conditioned, 1 otherwise
|
||||||
|
mask = torch.ones(bs, 1, T, H, W, device=device, dtype=dtype)
|
||||||
|
mask[:, :, :encoded.shape[2]] = 0.0
|
||||||
|
|
||||||
|
return latent, mask
|
||||||
@@ -181,6 +181,27 @@ export const modelArchs: ModelArch[] = [
|
|||||||
disableSections: ['network.conv'],
|
disableSections: ['network.conv'],
|
||||||
additionalSections: ['datasets.num_frames', 'model.low_vram'],
|
additionalSections: ['datasets.num_frames', 'model.low_vram'],
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: 'wan22_5b',
|
||||||
|
label: 'Wan 2.2 TI2V (5B)',
|
||||||
|
group: 'video',
|
||||||
|
isVideoModel: true,
|
||||||
|
defaults: {
|
||||||
|
// default updates when [selected, unselected] in the UI
|
||||||
|
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.2-TI2V-5B-Diffusers', defaultNameOrPath],
|
||||||
|
'config.process[0].model.quantize': [true, false],
|
||||||
|
'config.process[0].model.quantize_te': [true, false],
|
||||||
|
'config.process[0].model.low_vram': [true, false],
|
||||||
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
|
'config.process[0].sample.num_frames': [121, 1],
|
||||||
|
'config.process[0].sample.fps': [24, 1],
|
||||||
|
'config.process[0].sample.width': [768, 1024],
|
||||||
|
'config.process[0].sample.height': [768, 1024],
|
||||||
|
},
|
||||||
|
disableSections: ['network.conv'],
|
||||||
|
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'],
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: 'lumina2',
|
name: 'lumina2',
|
||||||
label: 'Lumina2',
|
label: 'Lumina2',
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
VERSION = "0.3.13"
|
VERSION = "0.3.14"
|
||||||
Reference in New Issue
Block a user