mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add support for Wan2.2 5B
This commit is contained in:
@@ -3,13 +3,15 @@ from .hidream import HidreamModel, HidreamE1Model
|
||||
from .f_light import FLiteModel
|
||||
from .omnigen2 import OmniGen2Model
|
||||
from .flux_kontext import FluxKontextModel
|
||||
from .wan22 import Wan22Model
|
||||
|
||||
AI_TOOLKIT_MODELS = [
|
||||
# put a list of models here
|
||||
ChromaModel,
|
||||
HidreamModel,
|
||||
HidreamE1Model,
|
||||
FLiteModel,
|
||||
OmniGen2Model,
|
||||
FluxKontextModel
|
||||
ChromaModel,
|
||||
HidreamModel,
|
||||
HidreamE1Model,
|
||||
FLiteModel,
|
||||
OmniGen2Model,
|
||||
FluxKontextModel,
|
||||
Wan22Model,
|
||||
]
|
||||
|
||||
1
extensions_built_in/diffusion_models/wan22/__init__.py
Normal file
1
extensions_built_in/diffusion_models/wan22/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .wan22_model import Wan22Model
|
||||
259
extensions_built_in/diffusion_models/wan22/wan22_model.py
Normal file
259
extensions_built_in/diffusion_models/wan22/wan22_model.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import torch
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from PIL import Image
|
||||
from diffusers import UniPCMultistepScheduler
|
||||
import torch
|
||||
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||
from toolkit.samplers.custom_flowmatch_sampler import (
|
||||
CustomFlowMatchEulerDiscreteScheduler,
|
||||
)
|
||||
from .wan22_pipeline import Wan22Pipeline
|
||||
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from torchvision.transforms import functional as TF
|
||||
|
||||
from toolkit.models.wan21.wan21 import Wan21, AggressiveWanUnloadPipeline
|
||||
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning_v22
|
||||
|
||||
|
||||
# for generation only?
|
||||
scheduler_configUniPC = {
|
||||
"_class_name": "UniPCMultistepScheduler",
|
||||
"_diffusers_version": "0.35.0.dev0",
|
||||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
"beta_start": 0.0001,
|
||||
"disable_corrector": [],
|
||||
"dynamic_thresholding_ratio": 0.995,
|
||||
"final_sigmas_type": "zero",
|
||||
"flow_shift": 5.0,
|
||||
"lower_order_final": True,
|
||||
"num_train_timesteps": 1000,
|
||||
"predict_x0": True,
|
||||
"prediction_type": "flow_prediction",
|
||||
"rescale_betas_zero_snr": False,
|
||||
"sample_max_value": 1.0,
|
||||
"solver_order": 2,
|
||||
"solver_p": None,
|
||||
"solver_type": "bh2",
|
||||
"steps_offset": 0,
|
||||
"thresholding": False,
|
||||
"time_shift_type": "exponential",
|
||||
"timestep_spacing": "linspace",
|
||||
"trained_betas": None,
|
||||
"use_beta_sigmas": False,
|
||||
"use_dynamic_shifting": False,
|
||||
"use_exponential_sigmas": False,
|
||||
"use_flow_sigmas": True,
|
||||
"use_karras_sigmas": False,
|
||||
}
|
||||
|
||||
# for training. I think it is right
|
||||
scheduler_config = {
|
||||
"num_train_timesteps": 1000,
|
||||
"shift": 5.0,
|
||||
"use_dynamic_shifting": False,
|
||||
}
|
||||
|
||||
|
||||
class Wan22Model(Wan21):
|
||||
arch = "wan22_5b"
|
||||
_wan_generation_scheduler_config = scheduler_configUniPC
|
||||
_wan_expand_timesteps = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device,
|
||||
model_config: ModelConfig,
|
||||
dtype="bf16",
|
||||
custom_pipeline=None,
|
||||
noise_scheduler=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
device=device,
|
||||
model_config=model_config,
|
||||
dtype=dtype,
|
||||
custom_pipeline=custom_pipeline,
|
||||
noise_scheduler=noise_scheduler,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._wan_cache = None
|
||||
|
||||
def get_bucket_divisibility(self):
|
||||
# 16x compression and 2x2 patch size
|
||||
return 32
|
||||
|
||||
def get_generation_pipeline(self):
|
||||
scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
|
||||
pipeline = Wan22Pipeline(
|
||||
vae=self.vae,
|
||||
transformer=self.model,
|
||||
transformer_2=self.model,
|
||||
text_encoder=self.text_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
scheduler=scheduler,
|
||||
expand_timesteps=self._wan_expand_timesteps,
|
||||
device=self.device_torch,
|
||||
aggressive_offload=self.model_config.low_vram,
|
||||
)
|
||||
|
||||
pipeline = pipeline.to(self.device_torch)
|
||||
|
||||
return pipeline
|
||||
|
||||
# static method to get the scheduler
|
||||
@staticmethod
|
||||
def get_train_scheduler():
|
||||
scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
|
||||
return scheduler
|
||||
|
||||
def get_base_model_version(self):
|
||||
return "wan_2.2_5b"
|
||||
|
||||
def generate_single_image(
|
||||
self,
|
||||
pipeline: AggressiveWanUnloadPipeline,
|
||||
gen_config: GenerateImageConfig,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
unconditional_embeds: PromptEmbeds,
|
||||
generator: torch.Generator,
|
||||
extra: dict,
|
||||
):
|
||||
# reactivate progress bar since this is slooooow
|
||||
pipeline.set_progress_bar_config(disable=False)
|
||||
|
||||
num_frames = (
|
||||
(gen_config.num_frames - 1) // 4
|
||||
) * 4 + 1 # make sure it is divisible by 4 + 1
|
||||
gen_config.num_frames = num_frames
|
||||
|
||||
height = gen_config.height
|
||||
width = gen_config.width
|
||||
noise_mask = None
|
||||
if gen_config.ctrl_img is not None:
|
||||
control_img = Image.open(gen_config.ctrl_img).convert("RGB")
|
||||
|
||||
d = self.get_bucket_divisibility()
|
||||
|
||||
# make sure they are divisible by d
|
||||
height = height // d * d
|
||||
width = width // d * d
|
||||
|
||||
# resize the control image
|
||||
control_img = control_img.resize((width, height), Image.LANCZOS)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = pipeline.prepare_latents(
|
||||
1,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
gen_config.num_frames,
|
||||
torch.float32,
|
||||
self.device_torch,
|
||||
generator,
|
||||
None,
|
||||
).to(self.torch_dtype)
|
||||
|
||||
first_frame_n1p1 = (
|
||||
TF.to_tensor(control_img)
|
||||
.unsqueeze(0)
|
||||
.to(self.device_torch, dtype=self.torch_dtype)
|
||||
* 2.0
|
||||
- 1.0
|
||||
) # normalize to [-1, 1]
|
||||
|
||||
gen_config.latents, noise_mask = add_first_frame_conditioning_v22(
|
||||
latent_model_input=latents, first_frame=first_frame_n1p1, vae=self.vae
|
||||
)
|
||||
|
||||
output = pipeline(
|
||||
prompt_embeds=conditional_embeds.text_embeds.to(
|
||||
self.device_torch, dtype=self.torch_dtype
|
||||
),
|
||||
negative_prompt_embeds=unconditional_embeds.text_embeds.to(
|
||||
self.device_torch, dtype=self.torch_dtype
|
||||
),
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
latents=gen_config.latents,
|
||||
num_frames=gen_config.num_frames,
|
||||
generator=generator,
|
||||
return_dict=False,
|
||||
output_type="pil",
|
||||
noise_mask=noise_mask,
|
||||
**extra,
|
||||
)[0]
|
||||
|
||||
# shape = [1, frames, channels, height, width]
|
||||
batch_item = output[0] # list of pil images
|
||||
if gen_config.num_frames > 1:
|
||||
return batch_item # return the frames.
|
||||
else:
|
||||
# get just the first image
|
||||
img = batch_item[0]
|
||||
return img
|
||||
|
||||
def get_noise_prediction(
|
||||
self,
|
||||
latent_model_input: torch.Tensor,
|
||||
timestep: torch.Tensor, # 0 to 1000 scale
|
||||
text_embeddings: PromptEmbeds,
|
||||
batch: DataLoaderBatchDTO,
|
||||
**kwargs,
|
||||
):
|
||||
# videos come in (bs, num_frames, channels, height, width)
|
||||
# images come in (bs, channels, height, width)
|
||||
|
||||
# for wan, only do i2v for video for now. Images do normal t2i
|
||||
conditioned_latent = latent_model_input
|
||||
noise_mask = None
|
||||
|
||||
with torch.no_grad():
|
||||
frames = batch.tensor
|
||||
if len(frames.shape) == 4:
|
||||
first_frames = frames
|
||||
elif len(frames.shape) == 5:
|
||||
first_frames = frames[:, 0]
|
||||
# Add conditioning using the standalone function
|
||||
conditioned_latent, noise_mask = add_first_frame_conditioning_v22(
|
||||
latent_model_input=latent_model_input.to(
|
||||
self.device_torch, self.torch_dtype
|
||||
),
|
||||
first_frame=first_frames.to(self.device_torch, self.torch_dtype),
|
||||
vae=self.vae,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown frame shape {frames.shape}")
|
||||
|
||||
# make the noise mask
|
||||
if noise_mask is None:
|
||||
noise_mask = torch.ones(
|
||||
conditioned_latent.shape,
|
||||
dtype=conditioned_latent.dtype,
|
||||
device=conditioned_latent.device,
|
||||
)
|
||||
# todo write this better
|
||||
t_chunks = torch.chunk(timestep, timestep.shape[0])
|
||||
out_t_chunks = []
|
||||
for t in t_chunks:
|
||||
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
|
||||
temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()
|
||||
# batch_size, seq_len
|
||||
temp_ts = temp_ts.unsqueeze(0)
|
||||
out_t_chunks.append(temp_ts)
|
||||
timestep = torch.cat(out_t_chunks, dim=0)
|
||||
|
||||
noise_pred = self.model(
|
||||
hidden_states=conditioned_latent,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=text_embeddings.text_embeds,
|
||||
return_dict=False,
|
||||
**kwargs,
|
||||
)[0]
|
||||
return noise_pred
|
||||
263
extensions_built_in/diffusion_models/wan22/wan22_pipeline.py
Normal file
263
extensions_built_in/diffusion_models/wan22/wan22_pipeline.py
Normal file
@@ -0,0 +1,263 @@
|
||||
|
||||
import torch
|
||||
from toolkit.basic import flush
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan
|
||||
import torch
|
||||
from diffusers import FlowMatchEulerDiscreteScheduler
|
||||
from typing import List
|
||||
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
|
||||
from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
|
||||
|
||||
class Wan22Pipeline(WanPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
transformer: WanTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
transformer_2: Optional[WanTransformer3DModel] = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
expand_timesteps: bool = False, # Wan2.2 ti2v
|
||||
device: torch.device = torch.device("cuda"),
|
||||
aggressive_offload: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
transformer_2=transformer_2,
|
||||
boundary_ratio=boundary_ratio,
|
||||
expand_timesteps=expand_timesteps,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self._aggressive_offload = aggressive_offload
|
||||
self._exec_device = device
|
||||
@property
|
||||
def _execution_device(self):
|
||||
return self._exec_device
|
||||
|
||||
def __call__(
|
||||
self: WanPipeline,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator,
|
||||
List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None],
|
||||
PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
noise_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# unload vae and transformer
|
||||
vae_device = self.vae.device
|
||||
transformer_device = self.transformer.device
|
||||
text_encoder_device = self.text_encoder.device
|
||||
device = self.transformer.device
|
||||
|
||||
if self._aggressive_offload:
|
||||
print("Unloading vae")
|
||||
self.vae.to("cpu")
|
||||
self.text_encoder.to(device)
|
||||
flush()
|
||||
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if self._aggressive_offload:
|
||||
# unload text encoder
|
||||
print("Unloading text encoder")
|
||||
self.text_encoder.to("cpu")
|
||||
self.transformer.to(device)
|
||||
flush()
|
||||
|
||||
transformer_dtype = self.transformer.dtype
|
||||
prompt_embeds = prompt_embeds.to(device, transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(
|
||||
device, transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
mask = noise_mask
|
||||
if mask is None:
|
||||
mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - \
|
||||
num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(device, transformer_dtype)
|
||||
if self.config.expand_timesteps:
|
||||
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
|
||||
temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
|
||||
# batch_size, seq_len
|
||||
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
|
||||
else:
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * \
|
||||
(noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
# apply i2v mask
|
||||
latents = (latent_model_input * (1 - mask)) + (
|
||||
latents * mask
|
||||
)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(
|
||||
self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop(
|
||||
"prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop(
|
||||
"negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if self._aggressive_offload:
|
||||
# unload transformer
|
||||
print("Unloading transformer")
|
||||
self.transformer.to("cpu")
|
||||
if self.transformer_2 is not None:
|
||||
self.transformer_2.to("cpu")
|
||||
# load vae
|
||||
print("Loading Vae")
|
||||
self.vae.to(vae_device)
|
||||
flush()
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(
|
||||
video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return WanPipelineOutput(frames=video)
|
||||
@@ -1,7 +1,7 @@
|
||||
torchao==0.10.0
|
||||
safetensors
|
||||
git+https://github.com/jaretburkett/easy_dwpose.git
|
||||
git+https://github.com/huggingface/diffusers@00f95b9755718aabb65456e791b8408526ae6e76
|
||||
git+https://github.com/huggingface/diffusers@56d438727036b0918b30bbe3110c5fe1634ed19d
|
||||
transformers==4.52.4
|
||||
lycoris-lora==1.8.3
|
||||
flatten_json
|
||||
|
||||
@@ -20,6 +20,7 @@ import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import FromOriginalModelMixin
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.accelerate_utils import apply_forward_hook
|
||||
from diffusers.models.activations import get_activation
|
||||
@@ -34,6 +35,104 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
CACHE_T = 2
|
||||
|
||||
|
||||
class AvgDown3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
factor_t,
|
||||
factor_s=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.factor_t = factor_t
|
||||
self.factor_s = factor_s
|
||||
self.factor = self.factor_t * self.factor_s * self.factor_s
|
||||
|
||||
assert in_channels * self.factor % out_channels == 0
|
||||
self.group_size = in_channels * self.factor // out_channels
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
|
||||
pad = (0, 0, 0, 0, pad_t, 0)
|
||||
x = F.pad(x, pad)
|
||||
B, C, T, H, W = x.shape
|
||||
x = x.view(
|
||||
B,
|
||||
C,
|
||||
T // self.factor_t,
|
||||
self.factor_t,
|
||||
H // self.factor_s,
|
||||
self.factor_s,
|
||||
W // self.factor_s,
|
||||
self.factor_s,
|
||||
)
|
||||
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
|
||||
x = x.view(
|
||||
B,
|
||||
C * self.factor,
|
||||
T // self.factor_t,
|
||||
H // self.factor_s,
|
||||
W // self.factor_s,
|
||||
)
|
||||
x = x.view(
|
||||
B,
|
||||
self.out_channels,
|
||||
self.group_size,
|
||||
T // self.factor_t,
|
||||
H // self.factor_s,
|
||||
W // self.factor_s,
|
||||
)
|
||||
x = x.mean(dim=2)
|
||||
return x
|
||||
|
||||
|
||||
class DupUp3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
factor_t,
|
||||
factor_s=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.factor_t = factor_t
|
||||
self.factor_s = factor_s
|
||||
self.factor = self.factor_t * self.factor_s * self.factor_s
|
||||
|
||||
assert out_channels * self.factor % in_channels == 0
|
||||
self.repeats = out_channels * self.factor // in_channels
|
||||
|
||||
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
|
||||
x = x.repeat_interleave(self.repeats, dim=1)
|
||||
x = x.view(
|
||||
x.size(0),
|
||||
self.out_channels,
|
||||
self.factor_t,
|
||||
self.factor_s,
|
||||
self.factor_s,
|
||||
x.size(2),
|
||||
x.size(3),
|
||||
x.size(4),
|
||||
)
|
||||
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
|
||||
x = x.view(
|
||||
x.size(0),
|
||||
self.out_channels,
|
||||
x.size(2) * self.factor_t,
|
||||
x.size(4) * self.factor_s,
|
||||
x.size(6) * self.factor_s,
|
||||
)
|
||||
if first_chunk:
|
||||
x = x[:, :, self.factor_t - 1:, :, :]
|
||||
return x
|
||||
|
||||
class WanCausalConv3d(nn.Conv3d):
|
||||
r"""
|
||||
A custom 3D causal convolution layer with feature caching support.
|
||||
@@ -134,19 +233,23 @@ class WanResample(nn.Module):
|
||||
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, mode: str) -> None:
|
||||
def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# default to dim //2
|
||||
if upsample_out_dim is None:
|
||||
upsample_out_dim = dim // 2
|
||||
|
||||
# layers
|
||||
if mode == "upsample2d":
|
||||
self.resample = nn.Sequential(
|
||||
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
|
||||
)
|
||||
elif mode == "upsample3d":
|
||||
self.resample = nn.Sequential(
|
||||
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
|
||||
)
|
||||
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
@@ -363,6 +466,48 @@ class WanMidBlock(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class WanResidualDownBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_dim,
|
||||
out_dim,
|
||||
dropout,
|
||||
num_res_blocks,
|
||||
temperal_downsample=False,
|
||||
down_flag=False):
|
||||
super().__init__()
|
||||
|
||||
# Shortcut path with downsample
|
||||
self.avg_shortcut = AvgDown3D(
|
||||
in_dim,
|
||||
out_dim,
|
||||
factor_t=2 if temperal_downsample else 1,
|
||||
factor_s=2 if down_flag else 1,
|
||||
)
|
||||
|
||||
# Main path with residual blocks and downsample
|
||||
resnets = []
|
||||
for _ in range(num_res_blocks):
|
||||
resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
|
||||
in_dim = out_dim
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
# Add the final downsample block
|
||||
if down_flag:
|
||||
mode = "downsample3d" if temperal_downsample else "downsample2d"
|
||||
self.downsampler = WanResample(out_dim, mode=mode)
|
||||
else:
|
||||
self.downsampler = None
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
x_copy = x.clone()
|
||||
for resnet in self.resnets:
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
if self.downsampler is not None:
|
||||
x = self.downsampler(x, feat_cache, feat_idx)
|
||||
|
||||
return x + self.avg_shortcut(x_copy)
|
||||
|
||||
class WanEncoder3d(nn.Module):
|
||||
r"""
|
||||
A 3D encoder module.
|
||||
@@ -380,6 +525,7 @@ class WanEncoder3d(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
@@ -388,6 +534,7 @@ class WanEncoder3d(nn.Module):
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0,
|
||||
non_linearity: str = "silu",
|
||||
is_residual: bool = False, # wan 2.2 vae use a residual downblock
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -403,23 +550,35 @@ class WanEncoder3d(nn.Module):
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
|
||||
self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
self.down_blocks.append(WanAttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
if is_residual:
|
||||
self.down_blocks.append(
|
||||
WanResidualDownBlock(
|
||||
in_dim,
|
||||
out_dim,
|
||||
dropout,
|
||||
num_res_blocks,
|
||||
temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
|
||||
down_flag=i != len(dim_mult) - 1,
|
||||
)
|
||||
)
|
||||
else:
|
||||
for _ in range(num_res_blocks):
|
||||
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
self.down_blocks.append(WanAttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
||||
self.down_blocks.append(WanResample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
||||
self.down_blocks.append(WanResample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
|
||||
# middle blocks
|
||||
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
|
||||
@@ -469,6 +628,92 @@ class WanEncoder3d(nn.Module):
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
class WanResidualUpBlock(nn.Module):
|
||||
"""
|
||||
A block that handles upsampling for the WanVAE decoder.
|
||||
|
||||
Args:
|
||||
in_dim (int): Input dimension
|
||||
out_dim (int): Output dimension
|
||||
num_res_blocks (int): Number of residual blocks
|
||||
dropout (float): Dropout rate
|
||||
temperal_upsample (bool): Whether to upsample on temporal dimension
|
||||
up_flag (bool): Whether to upsample or not
|
||||
non_linearity (str): Type of non-linearity to use
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
num_res_blocks: int,
|
||||
dropout: float = 0.0,
|
||||
temperal_upsample: bool = False,
|
||||
up_flag: bool = False,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
if up_flag:
|
||||
self.avg_shortcut = DupUp3D(
|
||||
in_dim,
|
||||
out_dim,
|
||||
factor_t=2 if temperal_upsample else 1,
|
||||
factor_s=2,
|
||||
)
|
||||
else:
|
||||
self.avg_shortcut = None
|
||||
|
||||
# create residual blocks
|
||||
resnets = []
|
||||
current_dim = in_dim
|
||||
for _ in range(num_res_blocks + 1):
|
||||
resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
|
||||
current_dim = out_dim
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
# Add upsampling layer if needed
|
||||
if up_flag:
|
||||
upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
|
||||
self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
|
||||
else:
|
||||
self.upsampler = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||
"""
|
||||
Forward pass through the upsampling block.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor
|
||||
feat_cache (list, optional): Feature cache for causal convolutions
|
||||
feat_idx (list, optional): Feature index for cache management
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor
|
||||
"""
|
||||
x_copy = x.clone()
|
||||
|
||||
for resnet in self.resnets:
|
||||
if feat_cache is not None:
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = resnet(x)
|
||||
|
||||
if self.upsampler is not None:
|
||||
if feat_cache is not None:
|
||||
x = self.upsampler(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = self.upsampler(x)
|
||||
|
||||
if self.avg_shortcut is not None:
|
||||
x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
|
||||
|
||||
return x
|
||||
|
||||
class WanUpBlock(nn.Module):
|
||||
"""
|
||||
@@ -513,7 +758,7 @@ class WanUpBlock(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
|
||||
"""
|
||||
Forward pass through the upsampling block.
|
||||
|
||||
@@ -564,6 +809,8 @@ class WanDecoder3d(nn.Module):
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0,
|
||||
non_linearity: str = "silu",
|
||||
out_channels: int = 3,
|
||||
is_residual: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -577,7 +824,6 @@ class WanDecoder3d(nn.Module):
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
@@ -589,36 +835,47 @@ class WanDecoder3d(nn.Module):
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
if i > 0:
|
||||
if i > 0 and not is_residual:
|
||||
# wan vae 2.1
|
||||
in_dim = in_dim // 2
|
||||
|
||||
# Determine if we need upsampling
|
||||
# determine if we need upsampling
|
||||
up_flag = i != len(dim_mult) - 1
|
||||
# determine upsampling mode, if not upsampling, set to None
|
||||
upsample_mode = None
|
||||
if i != len(dim_mult) - 1:
|
||||
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
|
||||
|
||||
if up_flag and temperal_upsample[i]:
|
||||
upsample_mode = "upsample3d"
|
||||
elif up_flag:
|
||||
upsample_mode = "upsample2d"
|
||||
# Create and add the upsampling block
|
||||
up_block = WanUpBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
num_res_blocks=num_res_blocks,
|
||||
dropout=dropout,
|
||||
upsample_mode=upsample_mode,
|
||||
non_linearity=non_linearity,
|
||||
)
|
||||
if is_residual:
|
||||
up_block = WanResidualUpBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
num_res_blocks=num_res_blocks,
|
||||
dropout=dropout,
|
||||
temperal_upsample=temperal_upsample[i] if up_flag else False,
|
||||
up_flag= up_flag,
|
||||
non_linearity=non_linearity,
|
||||
)
|
||||
else:
|
||||
up_block = WanUpBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
num_res_blocks=num_res_blocks,
|
||||
dropout=dropout,
|
||||
upsample_mode=upsample_mode,
|
||||
non_linearity=non_linearity,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
# Update scale for next iteration
|
||||
if upsample_mode is not None:
|
||||
scale *= 2.0
|
||||
|
||||
# output blocks
|
||||
self.norm_out = WanRMS_norm(out_dim, images=False)
|
||||
self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
|
||||
self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
@@ -633,20 +890,11 @@ class WanDecoder3d(nn.Module):
|
||||
x = self.conv_in(x)
|
||||
|
||||
## middle
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
# middle
|
||||
x = self._gradient_checkpointing_func(self.mid_block, x, feat_cache, feat_idx)
|
||||
|
||||
## upsamples
|
||||
for up_block in self.up_blocks:
|
||||
x = self._gradient_checkpointing_func(up_block, x, feat_cache, feat_idx)
|
||||
|
||||
else:
|
||||
x = self.mid_block(x, feat_cache, feat_idx)
|
||||
x = self.mid_block(x, feat_cache, feat_idx)
|
||||
|
||||
## upsamples
|
||||
for up_block in self.up_blocks:
|
||||
x = up_block(x, feat_cache, feat_idx)
|
||||
## upsamples
|
||||
for up_block in self.up_blocks:
|
||||
x = up_block(x, feat_cache, feat_idx, first_chunk = first_chunk)
|
||||
|
||||
## head
|
||||
x = self.norm_out(x)
|
||||
@@ -665,7 +913,46 @@ class WanDecoder3d(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
def patchify(x, patch_size):
|
||||
# YiYi TODO: refactor this
|
||||
from einops import rearrange
|
||||
if patch_size == 1:
|
||||
return x
|
||||
if x.dim() == 4:
|
||||
x = rearrange(
|
||||
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c f (h q) (w r) -> b (c r q) f h w",
|
||||
q=patch_size,
|
||||
r=patch_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {x.shape}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def unpatchify(x, patch_size):
|
||||
# YiYi TODO: refactor this
|
||||
from einops import rearrange
|
||||
if patch_size == 1:
|
||||
return x
|
||||
|
||||
if x.dim() == 4:
|
||||
x = rearrange(
|
||||
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c r q) f h w -> b c f (h q) (w r)",
|
||||
q=patch_size,
|
||||
r=patch_size,
|
||||
)
|
||||
return x
|
||||
|
||||
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
|
||||
Introduced in [Wan 2.1].
|
||||
@@ -674,12 +961,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
for all models (such as downloading or saving).
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
base_dim: int = 96,
|
||||
decoder_base_dim: Optional[int] = None,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
num_res_blocks: int = 2,
|
||||
@@ -722,6 +1010,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
2.8251,
|
||||
1.9160,
|
||||
],
|
||||
is_residual: bool = False,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
patch_size: Optional[int] = None,
|
||||
scale_factor_temporal: Optional[int] = 4,
|
||||
scale_factor_spatial: Optional[int] = 8,
|
||||
clip_output: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -729,37 +1024,119 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
if decoder_base_dim is None:
|
||||
decoder_base_dim = base_dim
|
||||
|
||||
self.encoder = WanEncoder3d(
|
||||
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
|
||||
in_channels=in_channels, dim=base_dim, z_dim=z_dim * 2, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, dropout=dropout, is_residual=is_residual
|
||||
)
|
||||
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
|
||||
|
||||
self.decoder = WanDecoder3d(
|
||||
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
|
||||
dim=decoder_base_dim, z_dim=z_dim, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_upsample=self.temperal_upsample, dropout=dropout, out_channels=out_channels, is_residual=is_residual
|
||||
)
|
||||
|
||||
def clear_cache(self):
|
||||
def _count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, WanCausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
|
||||
|
||||
self._conv_num = _count_conv3d(self.decoder)
|
||||
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
||||
# to perform decoding of a single video latent at a time.
|
||||
self.use_slicing = False
|
||||
|
||||
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
|
||||
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
|
||||
# intermediate tiles together, the memory requirement can be lowered.
|
||||
self.use_tiling = False
|
||||
|
||||
# The minimal tile height and width for spatial tiling to be used
|
||||
self.tile_sample_min_height = 256
|
||||
self.tile_sample_min_width = 256
|
||||
|
||||
# The minimal distance between two spatial tiles
|
||||
self.tile_sample_stride_height = 192
|
||||
self.tile_sample_stride_width = 192
|
||||
|
||||
# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
|
||||
self._cached_conv_counts = {
|
||||
"decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
|
||||
if self.decoder is not None
|
||||
else 0,
|
||||
"encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
|
||||
if self.encoder is not None
|
||||
else 0,
|
||||
}
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
tile_sample_min_width: Optional[int] = None,
|
||||
tile_sample_stride_height: Optional[float] = None,
|
||||
tile_sample_stride_width: Optional[float] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
|
||||
Args:
|
||||
tile_sample_min_height (`int`, *optional*):
|
||||
The minimum height required for a sample to be separated into tiles across the height dimension.
|
||||
tile_sample_min_width (`int`, *optional*):
|
||||
The minimum width required for a sample to be separated into tiles across the width dimension.
|
||||
tile_sample_stride_height (`int`, *optional*):
|
||||
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
|
||||
no tiling artifacts produced across the height dimension.
|
||||
tile_sample_stride_width (`int`, *optional*):
|
||||
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
|
||||
artifacts produced across the width dimension.
|
||||
"""
|
||||
self.use_tiling = True
|
||||
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
||||
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
||||
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
|
||||
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def clear_cache(self):
|
||||
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
|
||||
self._conv_num = self._cached_conv_counts["decoder"]
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
# cache encode
|
||||
self._enc_conv_num = _count_conv3d(self.encoder)
|
||||
self._enc_conv_num = self._cached_conv_counts["encoder"]
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def _encode(self, x: torch.Tensor):
|
||||
_, _, num_frame, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
if self.config.patch_size is not None:
|
||||
x = patchify(x, patch_size=self.config.patch_size)
|
||||
iter_ = 1 + (num_frame - 1) // 4
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
@@ -773,8 +1150,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
enc = self.quant_conv(out)
|
||||
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
|
||||
enc = torch.cat([mu, logvar], dim=1)
|
||||
self.clear_cache()
|
||||
return enc
|
||||
|
||||
@@ -794,27 +1169,39 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
The latent representations of the encoded videos. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
h = self._encode(x)
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
self.clear_cache()
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True):
|
||||
_, _, num_frame, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
|
||||
iter_ = z.shape[2]
|
||||
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
self.clear_cache()
|
||||
x = self.post_quant_conv(z)
|
||||
for i in range(iter_):
|
||||
|
||||
for i in range(num_frame):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True)
|
||||
else:
|
||||
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
out = torch.clamp(out, min=-1.0, max=1.0)
|
||||
if self.config.clip_output:
|
||||
out = torch.clamp(out, min=-1.0, max=1.0)
|
||||
if self.config.patch_size is not None:
|
||||
out = unpatchify(out, patch_size=self.config.patch_size)
|
||||
self.clear_cache()
|
||||
if not return_dict:
|
||||
return (out,)
|
||||
@@ -836,12 +1223,161 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
decoded = self._decode(z).sample
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
||||
y / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of videos.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The latent representation of the encoded videos.
|
||||
"""
|
||||
_, _, num_frames, height, width = x.shape
|
||||
latent_height = height // self.spatial_compression_ratio
|
||||
latent_width = width // self.spatial_compression_ratio
|
||||
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
|
||||
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||
|
||||
blend_height = tile_latent_min_height - tile_latent_stride_height
|
||||
blend_width = tile_latent_min_width - tile_latent_stride_width
|
||||
|
||||
# Split x into overlapping tiles and encode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
rows = []
|
||||
for i in range(0, height, self.tile_sample_stride_height):
|
||||
row = []
|
||||
for j in range(0, width, self.tile_sample_stride_width):
|
||||
self.clear_cache()
|
||||
time = []
|
||||
frame_range = 1 + (num_frames - 1) // 4
|
||||
for k in range(frame_range):
|
||||
self._enc_conv_idx = [0]
|
||||
if k == 0:
|
||||
tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
|
||||
else:
|
||||
tile = x[
|
||||
:,
|
||||
:,
|
||||
1 + 4 * (k - 1) : 1 + 4 * k,
|
||||
i : i + self.tile_sample_min_height,
|
||||
j : j + self.tile_sample_min_width,
|
||||
]
|
||||
tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
|
||||
tile = self.quant_conv(tile)
|
||||
time.append(tile)
|
||||
row.append(torch.cat(time, dim=2))
|
||||
rows.append(row)
|
||||
self.clear_cache()
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
|
||||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
_, _, num_frames, height, width = z.shape
|
||||
sample_height = height * self.spatial_compression_ratio
|
||||
sample_width = width * self.spatial_compression_ratio
|
||||
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
|
||||
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||
|
||||
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
|
||||
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
|
||||
|
||||
# Split z into overlapping tiles and decode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
rows = []
|
||||
for i in range(0, height, tile_latent_stride_height):
|
||||
row = []
|
||||
for j in range(0, width, tile_latent_stride_width):
|
||||
self.clear_cache()
|
||||
time = []
|
||||
for k in range(num_frames):
|
||||
self._conv_idx = [0]
|
||||
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
|
||||
tile = self.post_quant_conv(tile)
|
||||
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||
time.append(decoded)
|
||||
row.append(torch.cat(time, dim=2))
|
||||
rows.append(row)
|
||||
self.clear_cache()
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
|
||||
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
@@ -862,4 +1398,4 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, return_dict=return_dict)
|
||||
return dec
|
||||
return dec
|
||||
@@ -89,12 +89,18 @@ class AggressiveWanUnloadPipeline(WanPipeline):
|
||||
transformer: WanTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
transformer_2: Optional[WanTransformer3DModel] = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
expand_timesteps: bool = False, # Wan2.2 ti2v
|
||||
device: torch.device = torch.device("cuda"),
|
||||
):
|
||||
super().__init__(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
transformer_2=transformer_2,
|
||||
boundary_ratio=boundary_ratio,
|
||||
expand_timesteps=expand_timesteps,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
@@ -300,6 +306,8 @@ class AggressiveWanUnloadPipeline(WanPipeline):
|
||||
|
||||
class Wan21(BaseModel):
|
||||
arch = 'wan21'
|
||||
_wan_generation_scheduler_config = scheduler_configUniPC
|
||||
_wan_expand_timesteps = False
|
||||
def __init__(
|
||||
self,
|
||||
device,
|
||||
@@ -331,7 +339,7 @@ class Wan21(BaseModel):
|
||||
dtype = self.torch_dtype
|
||||
model_path = self.model_config.name_or_path
|
||||
|
||||
self.print_and_status_update("Loading Wan2.1 model")
|
||||
self.print_and_status_update("Loading Wan model")
|
||||
subfolder = 'transformer'
|
||||
transformer_path = model_path
|
||||
if os.path.exists(transformer_path):
|
||||
@@ -380,7 +388,6 @@ class Wan21(BaseModel):
|
||||
# patch the state dict method
|
||||
patch_dequantization_on_save(transformer)
|
||||
quantization_type = get_qtype(self.model_config.qtype)
|
||||
self.print_and_status_update("Quantizing transformer")
|
||||
if self.model_config.low_vram:
|
||||
print("Quantizing blocks")
|
||||
orig_exclude = copy.deepcopy(quantization_args['exclude'])
|
||||
@@ -474,22 +481,26 @@ class Wan21(BaseModel):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def get_generation_pipeline(self):
|
||||
scheduler = UniPCMultistepScheduler(**scheduler_configUniPC)
|
||||
scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
|
||||
if self.model_config.low_vram:
|
||||
pipeline = AggressiveWanUnloadPipeline(
|
||||
vae=self.vae,
|
||||
transformer=self.model,
|
||||
transformer_2=self.model,
|
||||
text_encoder=self.text_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
scheduler=scheduler,
|
||||
expand_timesteps=self._wan_expand_timesteps,
|
||||
device=self.device_torch
|
||||
)
|
||||
else:
|
||||
pipeline = WanPipeline(
|
||||
vae=self.vae,
|
||||
transformer=self.unet,
|
||||
transformer_2=self.unet,
|
||||
text_encoder=self.text_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
expand_timesteps=self._wan_expand_timesteps,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
|
||||
@@ -48,11 +48,13 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
image_encoder: CLIPVisionModel,
|
||||
image_processor: CLIPImageProcessor,
|
||||
transformer: WanTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
image_processor: CLIPImageProcessor = None,
|
||||
image_encoder: CLIPVisionModel = None,
|
||||
transformer_2: WanTransformer3DModel = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
device: torch.device = torch.device("cuda"),
|
||||
):
|
||||
super().__init__(
|
||||
@@ -63,6 +65,8 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
image_processor=image_processor,
|
||||
transformer_2=transformer_2,
|
||||
boundary_ratio=boundary_ratio,
|
||||
)
|
||||
self._exec_device = device
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ def add_first_frame_conditioning(
|
||||
first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1)
|
||||
|
||||
# resize first frame to match the latent model input
|
||||
vae_scale_factor = 8
|
||||
vae_scale_factor = vae.config.scale_factor_spatial
|
||||
first_frame = F.interpolate(
|
||||
first_frame,
|
||||
size=(latent_model_input.shape[3] * vae_scale_factor, latent_model_input.shape[4] * vae_scale_factor),
|
||||
@@ -111,3 +111,55 @@ def add_first_frame_conditioning(
|
||||
[latent_model_input, first_frame_condition], dim=1)
|
||||
|
||||
return conditioned_latent
|
||||
|
||||
|
||||
def add_first_frame_conditioning_v22(
|
||||
latent_model_input,
|
||||
first_frame,
|
||||
vae
|
||||
):
|
||||
"""
|
||||
Overwrites first few time steps in latent_model_input with VAE-encoded first_frame,
|
||||
and returns the modified latent + binary mask (0=conditioned, 1=noise).
|
||||
|
||||
Args:
|
||||
latent_model_input: torch.Tensor of shape (bs, 48, T, H, W)
|
||||
first_frame: torch.Tensor of shape (bs, 3, H*scale, W*scale)
|
||||
vae: VAE model with .encode() and .config.latents_mean/std
|
||||
|
||||
Returns:
|
||||
latent: (bs, 48, T, H, W) - modified input latent
|
||||
mask: (bs, 1, T, H, W) - binary mask
|
||||
"""
|
||||
device = latent_model_input.device
|
||||
dtype = latent_model_input.dtype
|
||||
bs, _, T, H, W = latent_model_input.shape
|
||||
scale = vae.config.scale_factor_spatial
|
||||
target_h = H * scale
|
||||
target_w = W * scale
|
||||
|
||||
# Ensure shape
|
||||
if first_frame.ndim == 3:
|
||||
first_frame = first_frame.unsqueeze(0)
|
||||
if first_frame.shape[0] != bs:
|
||||
first_frame = first_frame.expand(bs, -1, -1, -1)
|
||||
|
||||
# Resize and encode
|
||||
first_frame_up = F.interpolate(first_frame, size=(target_h, target_w), mode="bilinear", align_corners=False)
|
||||
first_frame_up = first_frame_up.unsqueeze(2) # (bs, 3, 1, H, W)
|
||||
encoded = vae.encode(first_frame_up).latent_dist.sample().to(dtype).to(device)
|
||||
|
||||
# Normalize
|
||||
mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
|
||||
std = 1.0 / torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
|
||||
encoded = (encoded - mean) * std
|
||||
|
||||
# Replace in latent
|
||||
latent = latent_model_input.clone()
|
||||
latent[:, :, :encoded.shape[2]] = encoded # typically first frame: [:, :, 0]
|
||||
|
||||
# Mask: 0 where conditioned, 1 otherwise
|
||||
mask = torch.ones(bs, 1, T, H, W, device=device, dtype=dtype)
|
||||
mask[:, :, :encoded.shape[2]] = 0.0
|
||||
|
||||
return latent, mask
|
||||
@@ -181,6 +181,27 @@ export const modelArchs: ModelArch[] = [
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['datasets.num_frames', 'model.low_vram'],
|
||||
},
|
||||
{
|
||||
name: 'wan22_5b',
|
||||
label: 'Wan 2.2 TI2V (5B)',
|
||||
group: 'video',
|
||||
isVideoModel: true,
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.2-TI2V-5B-Diffusers', defaultNameOrPath],
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.low_vram': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].sample.num_frames': [121, 1],
|
||||
'config.process[0].sample.fps': [24, 1],
|
||||
'config.process[0].sample.width': [768, 1024],
|
||||
'config.process[0].sample.height': [768, 1024],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'],
|
||||
},
|
||||
{
|
||||
name: 'lumina2',
|
||||
label: 'Lumina2',
|
||||
|
||||
@@ -1 +1 @@
|
||||
VERSION = "0.3.13"
|
||||
VERSION = "0.3.14"
|
||||
Reference in New Issue
Block a user