Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Finished up first frame for i2v adapter
@@ -111,6 +111,8 @@ class CustomAdapter(torch.nn.Module):
        self.conditional_embeds: Optional[torch.Tensor] = None
        self.unconditional_embeds: Optional[torch.Tensor] = None

        self.cached_control_image_0_1: Optional[torch.Tensor] = None

        self.setup_adapter()

@@ -1069,12 +1071,33 @@ class CustomAdapter(torch.nn.Module):
    def trigger_pre_te(
        self,
        tensors_0_1: torch.Tensor,
        tensors_0_1: Optional[torch.Tensor]=None,
        tensors_preprocessed: Optional[torch.Tensor]=None,  # preprocessed by the dataloader
        is_training=False,
        has_been_preprocessed=False,
        batch_tensor: Optional[torch.Tensor]=None,
        quad_count=4,
        batch_size=1,
    ) -> PromptEmbeds:
        if tensors_0_1 is not None:
            # actual 0 - 1 image
            self.cached_control_image_0_1 = tensors_0_1
        else:
            # image has been processed through the dataloader and is prepped for vision encoder
            self.cached_control_image_0_1 = None
        if batch_tensor is not None and self.cached_control_image_0_1 is None:
            # convert it to 0 - 1
            to_cache = batch_tensor / 2 + 0.5
            # videos come in (bs, num_frames, channels, height, width)
            # images come in (bs, channels, height, width)
            # if it is a video, just grab the first frame
            if len(to_cache.shape) == 5:
                to_cache = to_cache[:, 0:1, :, :, :]
                to_cache = to_cache.squeeze(1)
            self.cached_control_image_0_1 = to_cache

        if tensors_preprocessed is not None and has_been_preprocessed:
            tensors_0_1 = tensors_preprocessed
        # if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
        if self.adapter_type in ['ilora', 'vision_direct', 'te_augmenter', 'i2v']:
            skip_unconditional = self.sd_ref().is_flux

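A minimal sketch of the first-frame caching done in trigger_pre_te above: the dataloader batch is assumed to be in [-1, 1], is rescaled to [0, 1], and only the first frame of a video batch is kept. All tensor shapes here are illustrative, not values from the toolkit.

```python
import torch

# hypothetical dataloader batch in [-1, 1]: (bs, num_frames, channels, height, width)
batch_tensor = torch.rand(2, 16, 3, 64, 64) * 2 - 1

# same rescale the adapter performs: [-1, 1] -> [0, 1]
to_cache = batch_tensor / 2 + 0.5

# videos are 5D; keep only the first frame and drop the frame axis
if to_cache.dim() == 5:
    to_cache = to_cache[:, 0:1, :, :, :].squeeze(1)

print(to_cache.shape)                                          # torch.Size([2, 3, 64, 64])
print(bool(to_cache.min() >= 0), bool(to_cache.max() <= 1))    # True True
```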
@@ -9,7 +9,6 @@ from diffusers import WanTransformer3DModel
from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding

from toolkit.util.shuffle import shuffle_tensor_along_axis

if TYPE_CHECKING:

@@ -98,15 +97,18 @@ class FrameEmbedder(torch.nn.Module):
    def __init__(
        self,
        adapter: 'I2VAdapter',
        orig_layer: torch.nn.Linear,
        in_channels=64,
        out_channels=3072
        orig_layer: torch.nn.Conv3d,
        in_channels=20,  # wan is 16 normally, and 36 with i2v so 20 new channels
    ):
        super().__init__()
        # only do the weight for the new input. We combine with the original linear layer
        init = torch.randn(out_channels, in_channels,
                           device=orig_layer.weight.device, dtype=orig_layer.weight.dtype) * 0.01
        self.weight = torch.nn.Parameter(init)
        # goes through a conv patch embedding first and is then flattened
        # hidden_states = self.patch_embedding(hidden_states)
        # hidden_states = hidden_states.flatten(2).transpose(1, 2)

        inner_dim = orig_layer.out_channels
        patch_size = adapter.sd_ref().model.config.patch_size

        self.patch_embedding = torch.nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)

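For context on the commented-out flatten/transpose lines, here is a rough sketch of how a Conv3d patch embedding turns a 5D latent into a token sequence. The inner_dim, patch_size, and latent shape below are assumptions chosen for illustration (roughly in line with Wan 2.1's published config), not values read from this repo.

```python
import torch

in_channels, inner_dim = 20, 1536          # assumed values for illustration
patch_size = (1, 2, 2)                     # assumed temporal/spatial patch size

patch_embedding = torch.nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

latent = torch.randn(1, in_channels, 21, 30, 52)   # (bs, c, latent_frames, h, w)
tokens = patch_embedding(latent)                   # (bs, inner_dim, 21, 15, 26)
tokens = tokens.flatten(2).transpose(1, 2)         # (bs, num_patches, inner_dim)
print(tokens.shape)                                # torch.Size([1, 8190, 1536])
```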
@@ -116,35 +118,24 @@ class FrameEmbedder(torch.nn.Module):
        cls,
        model: WanTransformer3DModel,
        adapter: 'I2VAdapter',
        num_control_images=1,
        has_inpainting_input=False
    ):
        # TODO implement this
        if model.__class__.__name__ == 'WanTransformer3DModel':
            num_adapter_in_channels = model.x_embedder.in_features * num_control_images
            new_channels = 20  # wan is 16 normally, and 36 with i2v so 20 new channels

            if has_inpainting_input:
                # inpainting has the mask before packing latents. it is normally 16 ch + 1ch mask
                # packed it is 64ch + 4ch mask
                # so we need to add 4 to the input channels
                num_adapter_in_channels += 4

            x_embedder: torch.nn.Linear = model.x_embedder
            orig_patch_embedding: torch.nn.Conv3d = model.patch_embedding
            img_embedder = cls(
                adapter,
                orig_layer=x_embedder,
                in_channels=num_adapter_in_channels,
                out_channels=x_embedder.out_features,
                orig_layer=orig_patch_embedding,
                in_channels=new_channels,
            )

            # hijack the forward method
            x_embedder._orig_ctrl_lora_forward = x_embedder.forward
            x_embedder.forward = img_embedder.forward
            orig_patch_embedding._orig_i2v_adapter_forward = orig_patch_embedding.forward
            orig_patch_embedding.forward = img_embedder.forward

            # update the config of the transformer
            model.config.in_channels = model.config.in_channels * \
                (num_control_images + 1)
            model.config["in_channels"] = model.config.in_channels
            # update the config of the transformer, only needed when merged in
            # model.config.in_channels = model.config.in_channels + new_channels
            # model.config["in_channels"] = model.config.in_channels + new_channels

            return img_embedder
        else:

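The `_orig_i2v_adapter_forward` bookkeeping above is a standard forward hijack: save the original bound forward on the module, then point `forward` at the adapter so every caller of the patch embedding goes through it first. A standalone sketch with hypothetical names (not toolkit APIs):

```python
import torch

conv = torch.nn.Conv3d(16, 32, kernel_size=(1, 2, 2), stride=(1, 2, 2))


class ForwardHijack(torch.nn.Module):
    # minimal stand-in for FrameEmbedder: keeps a handle to the original forward
    def __init__(self, orig_layer: torch.nn.Conv3d):
        super().__init__()
        self._orig_forward = orig_layer.forward

    def forward(self, x):
        # any pre/post processing would go here; delegate to the saved original
        return self._orig_forward(x)


hijack = ForwardHijack(conv)
conv._orig_i2v_adapter_forward = conv.forward  # same bookkeeping as in from_model above
conv.forward = hijack.forward                  # conv(x) now routes through the hijack

out = conv(torch.randn(1, 16, 5, 8, 8))
print(out.shape)  # torch.Size([1, 32, 5, 4, 4])
```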
@@ -159,30 +150,37 @@ class FrameEmbedder(torch.nn.Module):
        # make sure lora is not active
        if self.adapter_ref().control_lora is not None:
            self.adapter_ref().control_lora.is_active = False
        return self.orig_layer_ref()._orig_ctrl_lora_forward(x)

        if x.shape[1] > self.orig_layer_ref().in_channels:
            # we have i2v, so we need to remove the extra channels
            x = x[:, :self.orig_layer_ref().in_channels, :, :, :]
        return self.orig_layer_ref()._orig_i2v_adapter_forward(x)

        # make sure lora is active
        if self.adapter_ref().control_lora is not None:
            self.adapter_ref().control_lora.is_active = True

        # x is arranged channels cat(orig_input = 16, temporal_conditioning_mask = 4, encoded_first_frame=16)
        # (16 + 4 + 16) = 36 channels
        # (batch_size, 36, num_frames, latent_height, latent_width)

        orig_device = x.device
        orig_dtype = x.dtype

        orig_in = x[:, :16, :, :, :]
        orig_out = self.orig_layer_ref()._orig_i2v_adapter_forward(orig_in)

        # remove original stuff
        x = x[:, 16:, :, :, :]

        x = x.to(self.weight.device, dtype=self.weight.dtype)

        orig_weight = self.orig_layer_ref().weight.data.detach()
        orig_weight = orig_weight.to(
            self.weight.device, dtype=self.weight.dtype)
        linear_weight = torch.cat([orig_weight, self.weight], dim=1)

        bias = None
        if self.orig_layer_ref().bias is not None:
            bias = self.orig_layer_ref().bias.data.detach().to(
                self.weight.device, dtype=self.weight.dtype)

        x = torch.nn.functional.linear(x, linear_weight, bias)
        x = x.to(self.patch_embedding.weight.device, dtype=self.patch_embedding.weight.dtype)

        x = self.patch_embedding(x)

        x = x.to(orig_device, dtype=orig_dtype)

        # add the original out
        x = x + orig_out
        return x

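A small sketch of the channel layout the forward above unpacks: 16 original latent channels, 4 temporal-mask channels, and 16 channels of the VAE-encoded first frame, concatenated to 36. Shapes are illustrative.

```python
import torch

bs, frames, h, w = 1, 21, 30, 52                 # toy latent dimensions
x = torch.cat([
    torch.randn(bs, 16, frames, h, w),           # original latent input
    torch.ones(bs, 4, frames, h, w),             # temporal conditioning mask
    torch.randn(bs, 16, frames, h, w),           # VAE-encoded first frame
], dim=1)
print(x.shape)                                   # torch.Size([1, 36, 21, 30, 52])

orig_in = x[:, :16]                              # handled by the frozen original patch embedding
extra = x[:, 16:]                                # 20 new channels handled by the adapter weights
print(orig_in.shape[1], extra.shape[1])          # 16 20
```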
@@ -299,6 +297,8 @@ def new_wan_forward(
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    # prevent circular import
    from toolkit.models.wan21.wan_utils import add_first_frame_conditioning
    adapter: 'I2VAdapter' = self._i2v_adapter_ref()

    if adapter.is_active:

@@ -336,6 +336,30 @@ def new_wan_forward(
            # doing a normal training run, always use conditional embeds
            encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds

        # add the first frame conditioning
        if adapter.frame_embedder is not None:
            with torch.no_grad():
                # add the first frame conditioning
                conditioning_frame = adapter.adapter_ref().cached_control_image_0_1
                if conditioning_frame is None:
                    raise ValueError("No conditioning frame found")

                # make it -1 to 1
                conditioning_frame = (conditioning_frame * 2) - 1
                conditioning_frame = conditioning_frame.to(
                    hidden_states.device, dtype=hidden_states.dtype
                )

                # if doing a full denoise, the latent input may be full channels here, only get first 16
                if hidden_states.shape[1] > 16:
                    hidden_states = hidden_states[:, :16, :, :, :]

                hidden_states = add_first_frame_conditioning(
                    latent_model_input=hidden_states,
                    first_frame=conditioning_frame,
                    vae=adapter.adapter_ref().sd_ref().vae,
                )
    else:
        # not active, deactivate the condition embedder
        self.condition_embedder.image_embedder = None

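The cached control image is stored in [0, 1] and converted back to [-1, 1] before VAE encoding, and a full-channel latent is trimmed to its base 16 channels before conditioning is re-applied. A toy illustration of those two steps (all shapes assumed):

```python
import torch

conditioning_frame = torch.rand(1, 3, 240, 416)      # cached control image in [0, 1]
conditioning_frame = (conditioning_frame * 2) - 1    # make it -1 to 1 for the VAE

hidden_states = torch.randn(1, 36, 5, 30, 52)        # latent that already carries conditioning channels
if hidden_states.shape[1] > 16:
    hidden_states = hidden_states[:, :16, :, :, :]   # keep only the base latent channels

print(bool(conditioning_frame.min() >= -1), hidden_states.shape)
# True torch.Size([1, 16, 5, 30, 52])
```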
@@ -450,9 +474,7 @@ class I2VAdapter(torch.nn.Module):
        if self.config.i2v_do_start_frame:
            self.frame_embedder = FrameEmbedder.from_model(
                sd.unet,
                self,
                num_control_images=config.num_control_images,
                has_inpainting_input=config.has_inpainting_input
                self
            )
            self.frame_embedder.to(self.device_torch)

@@ -41,6 +41,8 @@ from .wan21 import \
    scheduler_config, \
    Wan21

from .wan_utils import add_first_frame_conditioning


class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
    def __init__(

@@ -269,6 +271,9 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
        image = {k: v.to(self.image_encoder.device, dtype=self.image_encoder.dtype) for k, v in image.items()}
        image_embeds = self.image_encoder(**image, output_hidden_states=True)
        return image_embeds.hidden_states[-2]


class Wan21I2V(Wan21):

@@ -488,46 +493,12 @@ class Wan21I2V(Wan21):
        image_embeds = image_embeds_full.hidden_states[-2]
        image_embeds = image_embeds.to(self.device_torch, dtype=self.torch_dtype)

        # condition latent
        # first_frames shape is (bs, channels, height, width)
        # wan needs latents in (bs, channels, num_frames, height, width)
        first_frames = first_frames.unsqueeze(2)
        # video condition: the first frame is the real frame, the rest are zeros
        num_frames = frames.shape[1]

        zero_frame = torch.zeros_like(first_frames)
        video_condition = torch.cat([
            first_frames,
            *[zero_frame for _ in range(num_frames - 1)]
        ], dim=2)

        # our vae encoder expects (bs, num_frames, channels, height, width)
        # permute to (bs, channels, num_frames, height, width)
        video_condition = video_condition.permute(0, 2, 1, 3, 4)

        latent_condition = self.encode_images(
            video_condition,
            device=self.device_torch,
            dtype=self.torch_dtype,
        # Add conditioning using the standalone function
        conditioned_latent = add_first_frame_conditioning(
            latent_model_input=latent_model_input,
            first_frame=first_frames,
            vae=self.vae
        )
        latent_condition = latent_condition.to(self.device_torch, dtype=self.torch_dtype)

        batch_size = frames.shape[0]
        latent_height = latent_condition.shape[3]
        latent_width = latent_condition.shape[4]

        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
        mask_lat_size[:, :, list(range(1, num_frames))] = 0
        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.pipeline.vae_scale_factor_temporal)
        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(batch_size, -1, self.pipeline.vae_scale_factor_temporal, latent_height, latent_width)
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(self.device_torch, dtype=self.torch_dtype)

        # return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
        first_frame_condition = torch.concat([mask_lat_size, latent_condition], dim=1)
        conditioned_latent = torch.cat([latent_model_input, first_frame_condition], dim=1)

        noise_pred = self.model(
            hidden_states=conditioned_latent,

@@ -537,4 +508,4 @@ class Wan21I2V(Wan21):
            return_dict=False,
            **kwargs
        )[0]
        return noise_pred
        return noise_pred

toolkit/models/wan21/wan_utils.py (new file, 102 lines)
@@ -0,0 +1,102 @@
import torch
import torch.nn.functional as F


def add_first_frame_conditioning(
    latent_model_input,
    first_frame,
    vae
):
    """
    Adds first frame conditioning to a video diffusion model input.

    Args:
        latent_model_input: Original latent input (bs, channels, num_frames, height, width)
        first_frame: Tensor of first frame to condition on (bs, channels, height, width)
        vae: VAE model for encoding the conditioning

    Returns:
        conditioned_latent: The complete conditioned latent input (bs, 36, num_frames, height, width)
    """
    device = latent_model_input.device
    dtype = latent_model_input.dtype
    vae_scale_factor_temporal = 2 ** sum(vae.temperal_downsample)

    # Get number of frames from latent model input
    _, _, num_latent_frames, _, _ = latent_model_input.shape

    # Calculate original number of frames
    # For n original frames, there are (n-1)//4 + 1 latent frames
    # So to get n: n = (num_latent_frames-1)*4 + 1
    num_frames = (num_latent_frames - 1) * 4 + 1

    if len(first_frame.shape) == 3:
        # we have a single image
        first_frame = first_frame.unsqueeze(0)

    # if it doesn't match the batch size, we need to expand it
    if first_frame.shape[0] != latent_model_input.shape[0]:
        first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1)

    # resize first frame to match the latent model input
    vae_scale_factor = 8
    first_frame = F.interpolate(
        first_frame,
        size=(latent_model_input.shape[3] * vae_scale_factor, latent_model_input.shape[4] * vae_scale_factor),
        mode='bilinear',
        align_corners=False
    )

    # Add temporal dimension to first frame
    first_frame = first_frame.unsqueeze(2)

    # Create video condition with first frame and zeros for remaining frames
    zero_frame = torch.zeros_like(first_frame)
    video_condition = torch.cat([
        first_frame,
        *[zero_frame for _ in range(num_frames - 1)]
    ], dim=2)

    # Prepare for VAE encoding (bs, channels, num_frames, height, width)
    # video_condition = video_condition.permute(0, 2, 1, 3, 4)

    # Encode with VAE
    latent_condition = vae.encode(
        video_condition.to(device, dtype)
    ).latent_dist.sample()
    latent_condition = latent_condition.to(device, dtype)

    # Create mask: 1 for conditioning frames, 0 for frames to generate
    batch_size = first_frame.shape[0]
    latent_height = latent_condition.shape[3]
    latent_width = latent_condition.shape[4]

    # Initialize mask for all frames
    mask_lat_size = torch.ones(
        batch_size, 1, num_frames, latent_height, latent_width)

    # Set all non-first frames to 0
    mask_lat_size[:, :, list(range(1, num_frames))] = 0

    # Special handling for first frame
    first_frame_mask = mask_lat_size[:, :, 0:1]
    first_frame_mask = torch.repeat_interleave(
        first_frame_mask, dim=2, repeats=vae_scale_factor_temporal)

    # Combine first frame mask with rest
    mask_lat_size = torch.concat(
        [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)

    # Reshape and transpose for model input
    mask_lat_size = mask_lat_size.view(
        batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
    mask_lat_size = mask_lat_size.transpose(1, 2)
    mask_lat_size = mask_lat_size.to(device, dtype)

    # Combine conditioning with latent input
    first_frame_condition = torch.concat(
        [mask_lat_size, latent_condition], dim=1)
    conditioned_latent = torch.cat(
        [latent_model_input, first_frame_condition], dim=1)

    return conditioned_latent

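A hypothetical smoke test of add_first_frame_conditioning with a stub VAE, assuming the toolkit package from this repo is importable. The stub and every shape below are assumptions chosen only to exercise the shape math: 5 latent frames correspond to (5 - 1) * 4 + 1 = 17 source frames, the mask becomes (bs, 4, 5, h, w) after the repeat/view/transpose steps, and the result carries 16 + 4 + 16 = 36 channels.

```python
import torch
from types import SimpleNamespace

from toolkit.models.wan21.wan_utils import add_first_frame_conditioning


class StubVAE:
    # stand-in for the Wan VAE: 2 ** sum(temperal_downsample) == 4x temporal, 8x spatial, 16 latent channels
    temperal_downsample = [True, True]

    def encode(self, video):
        bs, _, f, h, w = video.shape
        latents = torch.randn(bs, 16, (f - 1) // 4 + 1, h // 8, w // 8)
        return SimpleNamespace(latent_dist=SimpleNamespace(sample=lambda: latents))


latent_model_input = torch.randn(1, 16, 5, 30, 52)   # (bs, c, latent_frames, h, w)
first_frame = torch.rand(1, 3, 240, 416) * 2 - 1     # first frame already in [-1, 1]

conditioned = add_first_frame_conditioning(latent_model_input, first_frame, StubVAE())
print(conditioned.shape)  # torch.Size([1, 36, 5, 30, 52])
```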