Finished up first frame for i2v adapter

This commit is contained in:
Jaret Burkett
2025-04-12 17:13:04 -06:00
parent cd37ccfc2e
commit 6fb44db6a0
5 changed files with 206 additions and 87 deletions

View File

@@ -111,6 +111,8 @@ class CustomAdapter(torch.nn.Module):
self.conditional_embeds: Optional[torch.Tensor] = None
self.unconditional_embeds: Optional[torch.Tensor] = None
self.cached_control_image_0_1: Optional[torch.Tensor] = None
self.setup_adapter()
@@ -1069,12 +1071,33 @@ class CustomAdapter(torch.nn.Module):
def trigger_pre_te(
self,
tensors_0_1: torch.Tensor,
tensors_0_1: Optional[torch.Tensor]=None,
tensors_preprocessed: Optional[torch.Tensor]=None, # preprocessed by the dataloader
is_training=False,
has_been_preprocessed=False,
batch_tensor: Optional[torch.Tensor]=None,
quad_count=4,
batch_size=1,
) -> PromptEmbeds:
if tensors_0_1 is not None:
# actual 0 - 1 image
self.cached_control_image_0_1 = tensors_0_1
else:
# the image was already preprocessed by the dataloader for the vision encoder, so there is no raw 0-1 image to cache
self.cached_control_image_0_1 = None
if batch_tensor is not None and self.cached_control_image_0_1 is None:
# convert it to 0 - 1
to_cache = batch_tensor / 2 + 0.5
# videos come in (bs, num_frames, channels, height, width)
# images come in (bs, channels, height, width)
# if it is a video, just grab the first frame
if len(to_cache.shape) == 5:
to_cache = to_cache[:, 0:1, :, :, :]
to_cache = to_cache.squeeze(1)
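# to_cache is now (bs, channels, height, width) in the 0-1 range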
self.cached_control_image_0_1 = to_cache
if tensors_preprocessed is not None and has_been_preprocessed:
tensors_0_1 = tensors_preprocessed
# if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
if self.adapter_type in ['ilora', 'vision_direct', 'te_augmenter', 'i2v']:
skip_unconditional = self.sd_ref().is_flux

View File

@@ -9,7 +9,6 @@ from diffusers import WanTransformer3DModel
from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
from toolkit.util.shuffle import shuffle_tensor_along_axis
if TYPE_CHECKING:
@@ -98,15 +97,18 @@ class FrameEmbedder(torch.nn.Module):
def __init__(
self,
adapter: 'I2VAdapter',
orig_layer: torch.nn.Linear,
in_channels=64,
out_channels=3072
orig_layer: torch.nn.Conv3d,
in_channels=20, # Wan is normally 16 channels and 36 with i2v, so 20 new channels
):
super().__init__()
# only do the weight for the new input. We combine with the original linear layer
init = torch.randn(out_channels, in_channels,
device=orig_layer.weight.device, dtype=orig_layer.weight.dtype) * 0.01
self.weight = torch.nn.Parameter(init)
# goes through a conv patch embedding first and is then flattened
# hidden_states = self.patch_embedding(hidden_states)
# hidden_states = hidden_states.flatten(2).transpose(1, 2)
inner_dim = orig_layer.out_channels
patch_size = adapter.sd_ref().model.config.patch_size
self.patch_embedding = torch.nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
self.adapter_ref: weakref.ref = weakref.ref(adapter)
self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)
@@ -116,35 +118,24 @@ class FrameEmbedder(torch.nn.Module):
cls,
model: WanTransformer3DModel,
adapter: 'I2VAdapter',
num_control_images=1,
has_inpainting_input=False
):
# TODO implement this
if model.__class__.__name__ == 'WanTransformer3DModel':
num_adapter_in_channels = model.x_embedder.in_features * num_control_images
new_channels = 20 # Wan is normally 16 channels and 36 with i2v, so 20 new channels
if has_inpainting_input:
# inpainting has the mask before packing latents. it is normally 16 ch + 1ch mask
# packed it is 64ch + 4ch mask
# so we need to add 4 to the input channels
num_adapter_in_channels += 4
x_embedder: torch.nn.Linear = model.x_embedder
orig_patch_embedding: torch.nn.Conv3d = model.patch_embedding
img_embedder = cls(
adapter,
orig_layer=x_embedder,
in_channels=num_adapter_in_channels,
out_channels=x_embedder.out_features,
orig_layer=orig_patch_embedding,
in_channels=new_channels,
)
# hijack the forward method
x_embedder._orig_ctrl_lora_forward = x_embedder.forward
x_embedder.forward = img_embedder.forward
orig_patch_embedding._orig_i2v_adapter_forward = orig_patch_embedding.forward
orig_patch_embedding.forward = img_embedder.forward
# update the config of the transformer
model.config.in_channels = model.config.in_channels * \
(num_control_images + 1)
model.config["in_channels"] = model.config.in_channels
# update the config of the transformer, only needed when merged in
# model.config.in_channels = model.config.in_channels + new_channels
# model.config["in_channels"] = model.config.in_channels + new_channels
return img_embedder
else:
@@ -159,30 +150,37 @@ class FrameEmbedder(torch.nn.Module):
# make sure lora is not active
if self.adapter_ref().control_lora is not None:
self.adapter_ref().control_lora.is_active = False
return self.orig_layer_ref()._orig_ctrl_lora_forward(x)
if x.shape[1] > self.orig_layer_ref().in_channels:
# we have i2v, so we need to remove the extra channels
x = x[:, :self.orig_layer_ref().in_channels, :, :, :]
return self.orig_layer_ref()._orig_i2v_adapter_forward(x)
# make sure lora is active
if self.adapter_ref().control_lora is not None:
self.adapter_ref().control_lora.is_active = True
# x channels are arranged as cat(orig_input=16, temporal_conditioning_mask=4, encoded_first_frame=16)
# (16 + 4 + 16) = 36 channels
# (batch_size, 36, num_frames, latent_height, latent_width)
orig_device = x.device
orig_dtype = x.dtype
orig_in = x[:, :16, :, :, :]
orig_out = self.orig_layer_ref()._orig_i2v_adapter_forward(orig_in)
# drop the original 16 channels; only the new conditioning channels remain
x = x[:, 16:, :, :, :]
x = x.to(self.weight.device, dtype=self.weight.dtype)
orig_weight = self.orig_layer_ref().weight.data.detach()
orig_weight = orig_weight.to(
self.weight.device, dtype=self.weight.dtype)
linear_weight = torch.cat([orig_weight, self.weight], dim=1)
bias = None
if self.orig_layer_ref().bias is not None:
bias = self.orig_layer_ref().bias.data.detach().to(
self.weight.device, dtype=self.weight.dtype)
x = torch.nn.functional.linear(x, linear_weight, bias)
x = x.to(self.patch_embedding.weight.device, dtype=self.patch_embedding.weight.dtype)
x = self.patch_embedding(x)
x = x.to(orig_device, dtype=orig_dtype)
# add the original patch embedding output
x = x + orig_out
return x
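For clarity, a minimal shape sketch of the split performed above: the original patch_embedding consumes the first 16 latent channels, the adapter's new Conv3d consumes the 20 conditioning channels, and the two patchified outputs are summed. The dimensions below are illustrative assumptions, not values read from the real Wan config.

import torch

# illustrative assumptions; inner_dim and patch size are placeholders, not the real Wan config values
bs, inner_dim, patch = 1, 1536, (1, 2, 2)
frames, height, width = 21, 60, 104

orig_embed = torch.nn.Conv3d(16, inner_dim, kernel_size=patch, stride=patch)  # stands in for the original patch_embedding
new_embed = torch.nn.Conv3d(20, inner_dim, kernel_size=patch, stride=patch)   # stands in for the adapter's patch_embedding

# cat(orig_input=16, temporal_conditioning_mask=4, encoded_first_frame=16) -> 36 channels
x = torch.randn(bs, 36, frames, height, width)
out = orig_embed(x[:, :16]) + new_embed(x[:, 16:])
print(out.shape)  # torch.Size([1, 1536, 21, 30, 52]), the same shape the original embedding would produce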
@@ -299,6 +297,8 @@ def new_wan_forward(
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
# prevent circular import
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning
adapter:'I2VAdapter' = self._i2v_adapter_ref()
if adapter.is_active:
@@ -336,6 +336,30 @@ def new_wan_forward(
# doing a normal training run, always use conditional embeds
encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds
# add the first frame conditioning
if adapter.frame_embedder is not None:
with torch.no_grad():
# encode the cached control image and build the conditioned latent
conditioning_frame = adapter.adapter_ref().cached_control_image_0_1
if conditioning_frame is None:
raise ValueError("No conditioning frame found")
# make it -1 to 1
conditioning_frame = (conditioning_frame * 2) - 1
conditioning_frame = conditioning_frame.to(
hidden_states.device, dtype=hidden_states.dtype
)
# when doing a full denoise, the latent input may already include the conditioning channels; keep only the first 16
if hidden_states.shape[1] > 16:
hidden_states = hidden_states[:, :16, :, :, :]
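# hidden_states is now the bare 16-channel latent; the conditioning channels are rebuilt below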
hidden_states = add_first_frame_conditioning(
latent_model_input=hidden_states,
first_frame=conditioning_frame,
vae=adapter.adapter_ref().sd_ref().vae,
)
else:
# not active; disable the image embedder on the condition embedder
self.condition_embedder.image_embedder = None
@@ -450,9 +474,7 @@ class I2VAdapter(torch.nn.Module):
if self.config.i2v_do_start_frame:
self.frame_embedder = FrameEmbedder.from_model(
sd.unet,
self,
num_control_images=config.num_control_images,
has_inpainting_input=config.has_inpainting_input
self
)
self.frame_embedder.to(self.device_torch)

View File

@@ -41,6 +41,8 @@ from .wan21 import \
scheduler_config, \
Wan21
from .wan_utils import add_first_frame_conditioning
class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
def __init__(
@@ -269,6 +271,9 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
image = {k: v.to(self.image_encoder.device, dtype=self.image_encoder.dtype) for k, v in image.items()}
image_embeds = self.image_encoder(**image, output_hidden_states=True)
return image_embeds.hidden_states[-2]
class Wan21I2V(Wan21):
@@ -488,46 +493,12 @@ class Wan21I2V(Wan21):
image_embeds = image_embeds_full.hidden_states[-2]
image_embeds = image_embeds.to(self.device_torch, dtype=self.torch_dtype)
# condition latent
# first_frames shape is (bs, channels, height, width)
# wan needs latents in (bs, channels, num_frames, height, width)
first_frames = first_frames.unsqueeze(2)
# the video condition is the first frame followed by zeros for the remaining frames
num_frames = frames.shape[1]
zero_frame = torch.zeros_like(first_frames)
video_condition = torch.cat([
first_frames,
*[zero_frame for _ in range(num_frames - 1)]
], dim=2)
# our vae encoder expects (bs, num_frames, channels, height, width)
# permute to (bs, channels, num_frames, height, width)
video_condition = video_condition.permute(0, 2, 1, 3, 4)
latent_condition = self.encode_images(
video_condition,
device=self.device_torch,
dtype=self.torch_dtype,
# Add conditioning using the standalone function
conditioned_latent = add_first_frame_conditioning(
latent_model_input=latent_model_input,
first_frame=first_frames,
vae=self.vae
)
latent_condition = latent_condition.to(self.device_torch, dtype=self.torch_dtype)
batch_size = frames.shape[0]
latent_height = latent_condition.shape[3]
latent_width = latent_condition.shape[4]
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, num_frames))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.pipeline.vae_scale_factor_temporal)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(batch_size, -1, self.pipeline.vae_scale_factor_temporal, latent_height, latent_width)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(self.device_torch, dtype=self.torch_dtype)
# return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
first_frame_condition = torch.concat([mask_lat_size, latent_condition], dim=1)
conditioned_latent = torch.cat([latent_model_input, first_frame_condition], dim=1)
noise_pred = self.model(
hidden_states=conditioned_latent,
@@ -537,4 +508,4 @@ class Wan21I2V(Wan21):
return_dict=False,
**kwargs
)[0]
return noise_pred
return noise_pred

View File

@@ -0,0 +1,102 @@
import torch
import torch.nn.functional as F
def add_first_frame_conditioning(
latent_model_input,
first_frame,
vae
):
"""
Adds first frame conditioning to a video diffusion model input.
Args:
latent_model_input: Original latent input (bs, channels, num_frames, height, width)
first_frame: Tensor of first frame to condition on (bs, channels, height, width)
vae: VAE model for encoding the conditioning
Returns:
conditioned_latent: The complete conditioned latent input (bs, 36, num_frames, height, width)
"""
device = latent_model_input.device
dtype = latent_model_input.dtype
vae_scale_factor_temporal = 2 ** sum(vae.temperal_downsample)
# Get number of frames from latent model input
_, _, num_latent_frames, _, _ = latent_model_input.shape
# Calculate original number of frames
# For n original frames, there are (n-1)//4 + 1 latent frames
# So to get n: n = (num_latent_frames-1)*4 + 1
num_frames = (num_latent_frames - 1) * 4 + 1
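# e.g. 21 latent frames -> (21 - 1) * 4 + 1 = 81 video frames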
if len(first_frame.shape) == 3:
# a single image without a batch dimension
first_frame = first_frame.unsqueeze(0)
# if it doesn't match the batch size, we need to expand it
if first_frame.shape[0] != latent_model_input.shape[0]:
first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1)
# resize first frame to match the latent model input
vae_scale_factor = 8
first_frame = F.interpolate(
first_frame,
size=(latent_model_input.shape[3] * vae_scale_factor, latent_model_input.shape[4] * vae_scale_factor),
mode='bilinear',
align_corners=False
)
# Add temporal dimension to first frame
first_frame = first_frame.unsqueeze(2)
# Create video condition with first frame and zeros for remaining frames
zero_frame = torch.zeros_like(first_frame)
video_condition = torch.cat([
first_frame,
*[zero_frame for _ in range(num_frames - 1)]
], dim=2)
# Prepare for VAE encoding (bs, channels, num_frames, height, width)
# video_condition = video_condition.permute(0, 2, 1, 3, 4)
# Encode with VAE
latent_condition = vae.encode(
video_condition.to(device, dtype)
).latent_dist.sample()
latent_condition = latent_condition.to(device, dtype)
# Create mask: 1 for conditioning frames, 0 for frames to generate
batch_size = first_frame.shape[0]
latent_height = latent_condition.shape[3]
latent_width = latent_condition.shape[4]
# Initialize mask for all frames
mask_lat_size = torch.ones(
batch_size, 1, num_frames, latent_height, latent_width)
# Set all non-first frames to 0
mask_lat_size[:, :, list(range(1, num_frames))] = 0
# Special handling for first frame
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=vae_scale_factor_temporal)
# Combine first frame mask with rest
mask_lat_size = torch.concat(
[first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
# Reshape and transpose for model input
mask_lat_size = mask_lat_size.view(
batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
mask_lat_size = mask_lat_size.transpose(1, 2)
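# for Wan (vae_scale_factor_temporal = 4) this gives (batch_size, 4, num_latent_frames, latent_height, latent_width)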
mask_lat_size = mask_lat_size.to(device, dtype)
# Combine conditioning with latent input
first_frame_condition = torch.concat(
[mask_lat_size, latent_condition], dim=1)
conditioned_latent = torch.cat(
[latent_model_input, first_frame_condition], dim=1)
return conditioned_latent
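A small usage sketch for the helper above. The latent and image sizes are placeholder assumptions, and loading the VAE via diffusers' AutoencoderKLWan is likewise an assumption; only the call signature comes from this file.

import torch
from diffusers import AutoencoderKLWan
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning

# assumed checkpoint; any Wan 2.1 VAE works here
vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae")

# placeholder shapes: 16-channel latents, 21 latent frames, 60x104 latent grid (480x832 pixels)
latents = torch.randn(1, 16, 21, 60, 104)
# conditioning frame in [-1, 1] as (bs, channels, height, width); it is resized internally
first_frame = torch.rand(1, 3, 480, 832) * 2 - 1

conditioned = add_first_frame_conditioning(
    latent_model_input=latents,
    first_frame=first_frame,
    vae=vae,
)
# conditioned is (1, 36, 21, 60, 104): 16 original + 4 mask + 16 first-frame latent channels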