mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Finished up first frame for i2v adapter
This commit is contained in:
@@ -1046,10 +1046,11 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
quad_count = random.randint(1, 4)
|
quad_count = random.randint(1, 4)
|
||||||
self.adapter.train()
|
self.adapter.train()
|
||||||
self.adapter.trigger_pre_te(
|
self.adapter.trigger_pre_te(
|
||||||
tensors_0_1=clip_images if not is_reg else None, # on regs we send none to get random noise
|
tensors_preprocessed=clip_images if not is_reg else None, # on regs we send none to get random noise
|
||||||
is_training=True,
|
is_training=True,
|
||||||
has_been_preprocessed=True,
|
has_been_preprocessed=True,
|
||||||
quad_count=quad_count,
|
quad_count=quad_count,
|
||||||
|
batch_tensor=batch.tensor if not is_reg else None,
|
||||||
batch_size=noisy_latents.shape[0]
|
batch_size=noisy_latents.shape[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -111,6 +111,8 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
|
|
||||||
self.conditional_embeds: Optional[torch.Tensor] = None
|
self.conditional_embeds: Optional[torch.Tensor] = None
|
||||||
self.unconditional_embeds: Optional[torch.Tensor] = None
|
self.unconditional_embeds: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
self.cached_control_image_0_1: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
self.setup_adapter()
|
self.setup_adapter()
|
||||||
|
|
||||||
@@ -1069,12 +1071,33 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
|
|
||||||
def trigger_pre_te(
|
def trigger_pre_te(
|
||||||
self,
|
self,
|
||||||
tensors_0_1: torch.Tensor,
|
tensors_0_1: Optional[torch.Tensor]=None,
|
||||||
|
tensors_preprocessed: Optional[torch.Tensor]=None, # preprocessed by the dataloader
|
||||||
is_training=False,
|
is_training=False,
|
||||||
has_been_preprocessed=False,
|
has_been_preprocessed=False,
|
||||||
|
batch_tensor: Optional[torch.Tensor]=None,
|
||||||
quad_count=4,
|
quad_count=4,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
) -> PromptEmbeds:
|
) -> PromptEmbeds:
|
||||||
|
if tensors_0_1 is not None:
|
||||||
|
# actual 0 - 1 image
|
||||||
|
self.cached_control_image_0_1 = tensors_0_1
|
||||||
|
else:
|
||||||
|
# image has been processed through the dataloader and is prepped for vision encoder
|
||||||
|
self.cached_control_image_0_1 = None
|
||||||
|
if batch_tensor is not None and self.cached_control_image_0_1 is None:
|
||||||
|
# convert it to 0 - 1
|
||||||
|
to_cache = batch_tensor / 2 + 0.5
|
||||||
|
# videos come in (bs, num_frames, channels, height, width)
|
||||||
|
# images come in (bs, channels, height, width)
|
||||||
|
# if it is a video, just grab the first frame
|
||||||
|
if len(to_cache.shape) == 5:
|
||||||
|
to_cache = to_cache[:, 0:1, :, :, :]
|
||||||
|
to_cache = to_cache.squeeze(1)
|
||||||
|
self.cached_control_image_0_1 = to_cache
|
||||||
|
|
||||||
|
if tensors_preprocessed is not None and has_been_preprocessed:
|
||||||
|
tensors_0_1 = tensors_preprocessed
|
||||||
# if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
|
# if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
|
||||||
if self.adapter_type in ['ilora', 'vision_direct', 'te_augmenter', 'i2v']:
|
if self.adapter_type in ['ilora', 'vision_direct', 'te_augmenter', 'i2v']:
|
||||||
skip_unconditional = self.sd_ref().is_flux
|
skip_unconditional = self.sd_ref().is_flux
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from diffusers import WanTransformer3DModel
|
|||||||
from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProcessor, CLIPVisionModelWithProjection
|
from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||||
from diffusers.models.attention_processor import Attention
|
from diffusers.models.attention_processor import Attention
|
||||||
from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
|
from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
|
||||||
|
|
||||||
from toolkit.util.shuffle import shuffle_tensor_along_axis
|
from toolkit.util.shuffle import shuffle_tensor_along_axis
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -98,15 +97,18 @@ class FrameEmbedder(torch.nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
adapter: 'I2VAdapter',
|
adapter: 'I2VAdapter',
|
||||||
orig_layer: torch.nn.Linear,
|
orig_layer: torch.nn.Conv3d,
|
||||||
in_channels=64,
|
in_channels=20, # wan is 16 normally, and 36 with i2v so 20 new channels
|
||||||
out_channels=3072
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# only do the weight for the new input. We combine with the original linear layer
|
# goes through a conv patch embedding first and is then flattened
|
||||||
init = torch.randn(out_channels, in_channels,
|
# hidden_states = self.patch_embedding(hidden_states)
|
||||||
device=orig_layer.weight.device, dtype=orig_layer.weight.dtype) * 0.01
|
# hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||||
self.weight = torch.nn.Parameter(init)
|
|
||||||
|
inner_dim = orig_layer.out_channels
|
||||||
|
patch_size = adapter.sd_ref().model.config.patch_size
|
||||||
|
|
||||||
|
self.patch_embedding = torch.nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
self.adapter_ref: weakref.ref = weakref.ref(adapter)
|
self.adapter_ref: weakref.ref = weakref.ref(adapter)
|
||||||
self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)
|
self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)
|
||||||
@@ -116,35 +118,24 @@ class FrameEmbedder(torch.nn.Module):
|
|||||||
cls,
|
cls,
|
||||||
model: WanTransformer3DModel,
|
model: WanTransformer3DModel,
|
||||||
adapter: 'I2VAdapter',
|
adapter: 'I2VAdapter',
|
||||||
num_control_images=1,
|
|
||||||
has_inpainting_input=False
|
|
||||||
):
|
):
|
||||||
# TODO implement this
|
|
||||||
if model.__class__.__name__ == 'WanTransformer3DModel':
|
if model.__class__.__name__ == 'WanTransformer3DModel':
|
||||||
num_adapter_in_channels = model.x_embedder.in_features * num_control_images
|
new_channels = 20 # wan is 16 normally, and 36 with i2v so 20 new channels
|
||||||
|
|
||||||
if has_inpainting_input:
|
orig_patch_embedding: torch.nn.Conv3d = model.patch_embedding
|
||||||
# inpainting has the mask before packing latents. it is normally 16 ch + 1ch mask
|
|
||||||
# packed it is 64ch + 4ch mask
|
|
||||||
# so we need to add 4 to the input channels
|
|
||||||
num_adapter_in_channels += 4
|
|
||||||
|
|
||||||
x_embedder: torch.nn.Linear = model.x_embedder
|
|
||||||
img_embedder = cls(
|
img_embedder = cls(
|
||||||
adapter,
|
adapter,
|
||||||
orig_layer=x_embedder,
|
orig_layer=orig_patch_embedding,
|
||||||
in_channels=num_adapter_in_channels,
|
in_channels=new_channels,
|
||||||
out_channels=x_embedder.out_features,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# hijack the forward method
|
# hijack the forward method
|
||||||
x_embedder._orig_ctrl_lora_forward = x_embedder.forward
|
orig_patch_embedding._orig_i2v_adapter_forward = orig_patch_embedding.forward
|
||||||
x_embedder.forward = img_embedder.forward
|
orig_patch_embedding.forward = img_embedder.forward
|
||||||
|
|
||||||
# update the config of the transformer
|
# update the config of the transformer, only needed when merged in
|
||||||
model.config.in_channels = model.config.in_channels * \
|
# model.config.in_channels = model.config.in_channels + new_channels
|
||||||
(num_control_images + 1)
|
# model.config["in_channels"] = model.config.in_channels + new_channels
|
||||||
model.config["in_channels"] = model.config.in_channels
|
|
||||||
|
|
||||||
return img_embedder
|
return img_embedder
|
||||||
else:
|
else:
|
||||||
@@ -159,30 +150,37 @@ class FrameEmbedder(torch.nn.Module):
|
|||||||
# make sure lora is not active
|
# make sure lora is not active
|
||||||
if self.adapter_ref().control_lora is not None:
|
if self.adapter_ref().control_lora is not None:
|
||||||
self.adapter_ref().control_lora.is_active = False
|
self.adapter_ref().control_lora.is_active = False
|
||||||
return self.orig_layer_ref()._orig_ctrl_lora_forward(x)
|
|
||||||
|
if x.shape[1] > self.orig_layer_ref().in_channels:
|
||||||
|
# we have i2v, so we need to remove the extra channels
|
||||||
|
x = x[:, :self.orig_layer_ref().in_channels, :, :, :]
|
||||||
|
return self.orig_layer_ref()._orig_i2v_adapter_forward(x)
|
||||||
|
|
||||||
# make sure lora is active
|
# make sure lora is active
|
||||||
if self.adapter_ref().control_lora is not None:
|
if self.adapter_ref().control_lora is not None:
|
||||||
self.adapter_ref().control_lora.is_active = True
|
self.adapter_ref().control_lora.is_active = True
|
||||||
|
|
||||||
|
# x is arranged channels cat(orig_input = 16, temporal_conditioning_mask = 4, encoded_first_frame=16)
|
||||||
|
# (16 + 4 + 16) = 36 channels
|
||||||
|
# (batch_size, 36, num_frames, latent_height, latent_width)
|
||||||
|
|
||||||
orig_device = x.device
|
orig_device = x.device
|
||||||
orig_dtype = x.dtype
|
orig_dtype = x.dtype
|
||||||
|
|
||||||
|
orig_in = x[:, :16, :, :, :]
|
||||||
|
orig_out = self.orig_layer_ref()._orig_i2v_adapter_forward(orig_in)
|
||||||
|
|
||||||
|
# remove original stuff
|
||||||
|
x = x[:, 16:, :, :, :]
|
||||||
|
|
||||||
x = x.to(self.weight.device, dtype=self.weight.dtype)
|
x = x.to(self.patch_embedding.weight.device, dtype=self.patch_embedding.weight.dtype)
|
||||||
|
|
||||||
orig_weight = self.orig_layer_ref().weight.data.detach()
|
|
||||||
orig_weight = orig_weight.to(
|
|
||||||
self.weight.device, dtype=self.weight.dtype)
|
|
||||||
linear_weight = torch.cat([orig_weight, self.weight], dim=1)
|
|
||||||
|
|
||||||
bias = None
|
|
||||||
if self.orig_layer_ref().bias is not None:
|
|
||||||
bias = self.orig_layer_ref().bias.data.detach().to(
|
|
||||||
self.weight.device, dtype=self.weight.dtype)
|
|
||||||
|
|
||||||
x = torch.nn.functional.linear(x, linear_weight, bias)
|
|
||||||
|
|
||||||
|
x = self.patch_embedding(x)
|
||||||
|
|
||||||
x = x.to(orig_device, dtype=orig_dtype)
|
x = x.to(orig_device, dtype=orig_dtype)
|
||||||
|
|
||||||
|
# add the original out
|
||||||
|
x = x + orig_out
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -299,6 +297,8 @@ def new_wan_forward(
|
|||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||||
|
# prevent circular import
|
||||||
|
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning
|
||||||
adapter:'I2VAdapter' = self._i2v_adapter_ref()
|
adapter:'I2VAdapter' = self._i2v_adapter_ref()
|
||||||
|
|
||||||
if adapter.is_active:
|
if adapter.is_active:
|
||||||
@@ -336,6 +336,30 @@ def new_wan_forward(
|
|||||||
# doing a normal training run, always use conditional embeds
|
# doing a normal training run, always use conditional embeds
|
||||||
encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds
|
encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds
|
||||||
|
|
||||||
|
# add the first frame conditioning
|
||||||
|
if adapter.frame_embedder is not None:
|
||||||
|
with torch.no_grad():
|
||||||
|
# add the first frame conditioning
|
||||||
|
conditioning_frame = adapter.adapter_ref().cached_control_image_0_1
|
||||||
|
if conditioning_frame is None:
|
||||||
|
raise ValueError("No conditioning frame found")
|
||||||
|
|
||||||
|
# make it -1 to 1
|
||||||
|
conditioning_frame = (conditioning_frame * 2) - 1
|
||||||
|
conditioning_frame = conditioning_frame.to(
|
||||||
|
hidden_states.device, dtype=hidden_states.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# if doing a full denoise, the latent input may be full channels here, only get first 16
|
||||||
|
if hidden_states.shape[1] > 16:
|
||||||
|
hidden_states = hidden_states[:, :16, :, :, :]
|
||||||
|
|
||||||
|
|
||||||
|
hidden_states = add_first_frame_conditioning(
|
||||||
|
latent_model_input=hidden_states,
|
||||||
|
first_frame=conditioning_frame,
|
||||||
|
vae=adapter.adapter_ref().sd_ref().vae,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# not active deactivate the condition embedder
|
# not active deactivate the condition embedder
|
||||||
self.condition_embedder.image_embedder = None
|
self.condition_embedder.image_embedder = None
|
||||||
@@ -450,9 +474,7 @@ class I2VAdapter(torch.nn.Module):
|
|||||||
if self.config.i2v_do_start_frame:
|
if self.config.i2v_do_start_frame:
|
||||||
self.frame_embedder = FrameEmbedder.from_model(
|
self.frame_embedder = FrameEmbedder.from_model(
|
||||||
sd.unet,
|
sd.unet,
|
||||||
self,
|
self
|
||||||
num_control_images=config.num_control_images,
|
|
||||||
has_inpainting_input=config.has_inpainting_input
|
|
||||||
)
|
)
|
||||||
self.frame_embedder.to(self.device_torch)
|
self.frame_embedder.to(self.device_torch)
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,8 @@ from .wan21 import \
|
|||||||
scheduler_config, \
|
scheduler_config, \
|
||||||
Wan21
|
Wan21
|
||||||
|
|
||||||
|
from .wan_utils import add_first_frame_conditioning
|
||||||
|
|
||||||
|
|
||||||
class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
|
class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -269,6 +271,9 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
|
|||||||
image = {k: v.to(self.image_encoder.device, dtype=self.image_encoder.dtype) for k, v in image.items()}
|
image = {k: v.to(self.image_encoder.device, dtype=self.image_encoder.dtype) for k, v in image.items()}
|
||||||
image_embeds = self.image_encoder(**image, output_hidden_states=True)
|
image_embeds = self.image_encoder(**image, output_hidden_states=True)
|
||||||
return image_embeds.hidden_states[-2]
|
return image_embeds.hidden_states[-2]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Wan21I2V(Wan21):
|
class Wan21I2V(Wan21):
|
||||||
@@ -488,46 +493,12 @@ class Wan21I2V(Wan21):
|
|||||||
image_embeds = image_embeds_full.hidden_states[-2]
|
image_embeds = image_embeds_full.hidden_states[-2]
|
||||||
image_embeds = image_embeds.to(self.device_torch, dtype=self.torch_dtype)
|
image_embeds = image_embeds.to(self.device_torch, dtype=self.torch_dtype)
|
||||||
|
|
||||||
# condition latent
|
# Add conditioning using the standalone function
|
||||||
# first_frames shape is (bs, channels, height, width)
|
conditioned_latent = add_first_frame_conditioning(
|
||||||
# wan needs latends in (bs, channels, num_frames, height, width)
|
latent_model_input=latent_model_input,
|
||||||
first_frames = first_frames.unsqueeze(2)
|
first_frame=first_frames,
|
||||||
# video condition is first frame is the frame, the rest are zeros
|
vae=self.vae
|
||||||
num_frames = frames.shape[1]
|
|
||||||
|
|
||||||
zero_frame = torch.zeros_like(first_frames)
|
|
||||||
video_condition = torch.cat([
|
|
||||||
first_frames,
|
|
||||||
*[zero_frame for _ in range(num_frames - 1)]
|
|
||||||
], dim=2)
|
|
||||||
|
|
||||||
# our vae encoder expects (bs, num_frames, channels, height, width)
|
|
||||||
# permute to (bs, channels, num_frames, height, width)
|
|
||||||
video_condition = video_condition.permute(0, 2, 1, 3, 4)
|
|
||||||
|
|
||||||
latent_condition = self.encode_images(
|
|
||||||
video_condition,
|
|
||||||
device=self.device_torch,
|
|
||||||
dtype=self.torch_dtype,
|
|
||||||
)
|
)
|
||||||
latent_condition = latent_condition.to(self.device_torch, dtype=self.torch_dtype)
|
|
||||||
|
|
||||||
batch_size = frames.shape[0]
|
|
||||||
latent_height = latent_condition.shape[3]
|
|
||||||
latent_width = latent_condition.shape[4]
|
|
||||||
|
|
||||||
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
|
|
||||||
mask_lat_size[:, :, list(range(1, num_frames))] = 0
|
|
||||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
|
||||||
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.pipeline.vae_scale_factor_temporal)
|
|
||||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
|
||||||
mask_lat_size = mask_lat_size.view(batch_size, -1, self.pipeline.vae_scale_factor_temporal, latent_height, latent_width)
|
|
||||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
|
||||||
mask_lat_size = mask_lat_size.to(self.device_torch, dtype=self.torch_dtype)
|
|
||||||
|
|
||||||
# return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
|
|
||||||
first_frame_condition = torch.concat([mask_lat_size, latent_condition], dim=1)
|
|
||||||
conditioned_latent = torch.cat([latent_model_input, first_frame_condition], dim=1)
|
|
||||||
|
|
||||||
noise_pred = self.model(
|
noise_pred = self.model(
|
||||||
hidden_states=conditioned_latent,
|
hidden_states=conditioned_latent,
|
||||||
@@ -537,4 +508,4 @@ class Wan21I2V(Wan21):
|
|||||||
return_dict=False,
|
return_dict=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
)[0]
|
)[0]
|
||||||
return noise_pred
|
return noise_pred
|
||||||
102
toolkit/models/wan21/wan_utils.py
Normal file
102
toolkit/models/wan21/wan_utils.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def add_first_frame_conditioning(
    latent_model_input,
    first_frame,
    vae
):
    """
    Adds first frame conditioning to a video diffusion model input.

    The result is ``cat([latent_model_input, mask, encoded_condition], dim=1)``
    where ``mask`` marks the first latent frame as known (1) and every other
    latent frame as "to be generated" (0), and ``encoded_condition`` is the VAE
    encoding of a video made of ``first_frame`` followed by all-zero frames.

    Args:
        latent_model_input: Original latent input (bs, channels, num_frames, height, width)
        first_frame: Tensor of first frame to condition on, either
            (bs, channels, height, width) or a single image (channels, height, width).
            A batch of 1 is expanded to the latent batch size.
        vae: VAE model for encoding the conditioning. Must expose
            ``temperal_downsample`` and ``encode(x).latent_dist.sample()``.

    Returns:
        conditioned_latent: The complete conditioned latent input
            (bs, channels + temporal_factor + vae_channels, num_frames, height, width);
            for Wan 2.1 that is 16 + 4 + 16 = 36 channels.
    """
    device = latent_model_input.device
    dtype = latent_model_input.dtype
    # `temperal_downsample` (sic) is the attribute name exposed by the VAE.
    vae_scale_factor_temporal = 2 ** sum(vae.temperal_downsample)

    batch_size, _, num_latent_frames, latent_h, latent_w = latent_model_input.shape

    # For n pixel frames there are (n - 1) // t + 1 latent frames, so invert
    # that with the temporal factor derived from the VAE rather than a
    # hard-coded 4 so the two can never disagree.
    num_frames = (num_latent_frames - 1) * vae_scale_factor_temporal + 1

    if first_frame.dim() == 3:
        # we have a single image
        first_frame = first_frame.unsqueeze(0)

    # if it doesn't match the batch size, we need to expand it
    if first_frame.shape[0] != batch_size:
        first_frame = first_frame.expand(batch_size, -1, -1, -1)

    # resize first frame to the pixel resolution that maps onto the latent grid
    # NOTE(review): assumes the VAE downsamples spatially by 8 - confirm for non-Wan VAEs
    vae_scale_factor = 8
    first_frame = F.interpolate(
        first_frame,
        size=(latent_h * vae_scale_factor, latent_w * vae_scale_factor),
        mode='bilinear',
        align_corners=False
    )

    # Add temporal dimension: (bs, channels, 1, height, width)
    first_frame = first_frame.unsqueeze(2)

    # Video condition is the first frame followed by zeros for the remaining
    # frames. Allocate the padding in one shot instead of concatenating
    # num_frames - 1 separate tensors.
    zero_frames = first_frame.new_zeros(
        first_frame.shape[0],
        first_frame.shape[1],
        num_frames - 1,
        first_frame.shape[3],
        first_frame.shape[4],
    )
    video_condition = torch.cat([first_frame, zero_frames], dim=2)

    # Encode with VAE, which takes (bs, channels, num_frames, height, width)
    latent_condition = vae.encode(
        video_condition.to(device, dtype)
    ).latent_dist.sample()
    latent_condition = latent_condition.to(device, dtype)

    # Mask: 1 for conditioning frames, 0 for frames to generate. The pixel-space
    # mask (first frame repeated `temporal_factor` times, then zeros) folded to
    # (bs, temporal_factor, latent_frames, h, w) is simply "first latent frame
    # is all ones", so build that directly on the target device/dtype.
    mask_lat_size = torch.zeros(
        first_frame.shape[0],
        vae_scale_factor_temporal,
        num_latent_frames,
        latent_condition.shape[3],
        latent_condition.shape[4],
        device=device,
        dtype=dtype,
    )
    mask_lat_size[:, :, 0] = 1

    # Combine conditioning with latent input along the channel axis
    first_frame_condition = torch.cat([mask_lat_size, latent_condition], dim=1)
    conditioned_latent = torch.cat(
        [latent_model_input, first_frame_condition], dim=1)

    return conditioned_latent
|
||||||
Reference in New Issue
Block a user