diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 4840ed84..5bf17441 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -1046,10 +1046,11 @@ class SDTrainer(BaseSDTrainProcess):
                     quad_count = random.randint(1, 4)
                     self.adapter.train()
                     self.adapter.trigger_pre_te(
-                        tensors_0_1=clip_images if not is_reg else None,  # on regs we send none to get random noise
+                        tensors_preprocessed=clip_images if not is_reg else None,  # on regs we send none to get random noise
                         is_training=True,
                         has_been_preprocessed=True,
                         quad_count=quad_count,
+                        batch_tensor=batch.tensor if not is_reg else None,
                         batch_size=noisy_latents.shape[0]
                     )
 
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 5761bf03..820fdc39 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -111,6 +111,8 @@ class CustomAdapter(torch.nn.Module):
 
         self.conditional_embeds: Optional[torch.Tensor] = None
         self.unconditional_embeds: Optional[torch.Tensor] = None
+
+        self.cached_control_image_0_1: Optional[torch.Tensor] = None
 
         self.setup_adapter()
 
@@ -1069,12 +1071,33 @@ class CustomAdapter(torch.nn.Module):
 
     def trigger_pre_te(
         self,
-        tensors_0_1: torch.Tensor,
+        tensors_0_1: Optional[torch.Tensor] = None,
+        tensors_preprocessed: Optional[torch.Tensor] = None,  # preprocessed by the dataloader
        is_training=False,
        has_been_preprocessed=False,
+        batch_tensor: Optional[torch.Tensor] = None,
        quad_count=4,
        batch_size=1,
    ) -> PromptEmbeds:
+        if tensors_0_1 is not None:
+            # actual 0 - 1 image
+            self.cached_control_image_0_1 = tensors_0_1
+        else:
+            # image has been processed through the dataloader and is prepped for the vision encoder
+            self.cached_control_image_0_1 = None
+        if batch_tensor is not None and self.cached_control_image_0_1 is None:
+            # convert it to 0 - 1
+            to_cache = batch_tensor / 2 + 0.5
+            # videos come in (bs, num_frames, channels, height, width)
+            # images come in (bs, channels, height, width)
+            # if it is a video, just grab the first frame
+            if len(to_cache.shape) == 5:
+                to_cache = to_cache[:, 0:1, :, :, :]
+                to_cache = to_cache.squeeze(1)
+            self.cached_control_image_0_1 = to_cache
+
+        if tensors_preprocessed is not None and has_been_preprocessed:
+            tensors_0_1 = tensors_preprocessed
        # if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
        if self.adapter_type in ['ilora', 'vision_direct', 'te_augmenter', 'i2v']:
            skip_unconditional = self.sd_ref().is_flux
diff --git a/toolkit/models/i2v_adapter.py b/toolkit/models/i2v_adapter.py
index 50b24cd4..d51efc69 100644
--- a/toolkit/models/i2v_adapter.py
+++ b/toolkit/models/i2v_adapter.py
@@ -9,7 +9,6 @@ from diffusers import WanTransformer3DModel
 from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProcessor, CLIPVisionModelWithProjection
 from diffusers.models.attention_processor import Attention
 from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
-
 from toolkit.util.shuffle import shuffle_tensor_along_axis
 
 if TYPE_CHECKING:
@@ -98,15 +97,18 @@ class FrameEmbedder(torch.nn.Module):
     def __init__(
         self,
         adapter: 'I2VAdapter',
-        orig_layer: torch.nn.Linear,
-        in_channels=64,
-        out_channels=3072
+        orig_layer: torch.nn.Conv3d,
+        in_channels=20,  # wan is 16 normally, and 36 with i2v, so 20 new channels
     ):
         super().__init__()
-        # only do the weight for the new input. We combine with the original linear layer
-        init = torch.randn(out_channels, in_channels,
-                           device=orig_layer.weight.device, dtype=orig_layer.weight.dtype) * 0.01
-        self.weight = torch.nn.Parameter(init)
+        # goes through a conv patch embedding first and is then flattened
+        # hidden_states = self.patch_embedding(hidden_states)
+        # hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+        inner_dim = orig_layer.out_channels
+        patch_size = adapter.sd_ref().model.config.patch_size
+
+        self.patch_embedding = torch.nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
 
         self.adapter_ref: weakref.ref = weakref.ref(adapter)
         self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)
@@ -116,35 +118,24 @@ class FrameEmbedder(torch.nn.Module):
         cls,
         model: WanTransformer3DModel,
         adapter: 'I2VAdapter',
-        num_control_images=1,
-        has_inpainting_input=False
     ):
-        # TODO implement this
         if model.__class__.__name__ == 'WanTransformer3DModel':
-            num_adapter_in_channels = model.x_embedder.in_features * num_control_images
+            new_channels = 20  # wan is 16 normally, and 36 with i2v, so 20 new channels
-            if has_inpainting_input:
-                # inpainting has the mask before packing latents. it is normally 16 ch + 1ch mask
-                # packed it is 64ch + 4ch mask
-                # so we need to add 4 to the input channels
-                num_adapter_in_channels += 4
-
-            x_embedder: torch.nn.Linear = model.x_embedder
+            orig_patch_embedding: torch.nn.Conv3d = model.patch_embedding
 
             img_embedder = cls(
                 adapter,
-                orig_layer=x_embedder,
-                in_channels=num_adapter_in_channels,
-                out_channels=x_embedder.out_features,
+                orig_layer=orig_patch_embedding,
+                in_channels=new_channels,
             )
 
             # hijack the forward method
-            x_embedder._orig_ctrl_lora_forward = x_embedder.forward
-            x_embedder.forward = img_embedder.forward
+            orig_patch_embedding._orig_i2v_adapter_forward = orig_patch_embedding.forward
+            orig_patch_embedding.forward = img_embedder.forward
 
-            # update the config of the transformer
-            model.config.in_channels = model.config.in_channels * \
-                (num_control_images + 1)
-            model.config["in_channels"] = model.config.in_channels
+            # update the config of the transformer, only needed when merged in
+            # model.config.in_channels = model.config.in_channels + new_channels
+            # model.config["in_channels"] = model.config.in_channels + new_channels
 
             return img_embedder
         else:
@@ -159,30 +150,37 @@ class FrameEmbedder(torch.nn.Module):
             # make sure lora is not active
             if self.adapter_ref().control_lora is not None:
                 self.adapter_ref().control_lora.is_active = False
-            return self.orig_layer_ref()._orig_ctrl_lora_forward(x)
+
+            if x.shape[1] > self.orig_layer_ref().in_channels:
+                # we have i2v, so we need to remove the extra channels
+                x = x[:, :self.orig_layer_ref().in_channels, :, :, :]
+            return self.orig_layer_ref()._orig_i2v_adapter_forward(x)
 
         # make sure lora is active
         if self.adapter_ref().control_lora is not None:
             self.adapter_ref().control_lora.is_active = True
+
+        # x is arranged channels cat(orig_input = 16, temporal_conditioning_mask = 4, encoded_first_frame = 16)
+        # (16 + 4 + 16) = 36 channels
+        # (batch_size, 36, num_frames, latent_height, latent_width)
 
         orig_device = x.device
         orig_dtype = x.dtype
+
+        orig_in = x[:, :16, :, :, :]
+        orig_out = self.orig_layer_ref()._orig_i2v_adapter_forward(orig_in)
+
+        # remove original stuff
+        x = x[:, 16:, :, :, :]
 
-        x = x.to(self.weight.device, dtype=self.weight.dtype)
-
-        orig_weight = self.orig_layer_ref().weight.data.detach()
-        orig_weight = orig_weight.to(
-            self.weight.device, dtype=self.weight.dtype)
-        linear_weight = torch.cat([orig_weight, self.weight], dim=1)
-
-        bias = None
-        if self.orig_layer_ref().bias is not None:
-            bias = self.orig_layer_ref().bias.data.detach().to(
-                self.weight.device, dtype=self.weight.dtype)
-
-        x = torch.nn.functional.linear(x, linear_weight, bias)
+        x = x.to(self.patch_embedding.weight.device, dtype=self.patch_embedding.weight.dtype)
+        x = self.patch_embedding(x)
+        x = x.to(orig_device, dtype=orig_dtype)
+
+        # add the original out
+        x = x + orig_out
 
         return x
 
@@ -299,6 +297,8 @@ def new_wan_forward(
     return_dict: bool = True,
     attention_kwargs: Optional[Dict[str, Any]] = None,
 ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+    # prevent circular import
+    from toolkit.models.wan21.wan_utils import add_first_frame_conditioning
     adapter:'I2VAdapter' = self._i2v_adapter_ref()
 
     if adapter.is_active:
@@ -336,6 +336,30 @@ def new_wan_forward(
                 # doing a normal training run, always use conditional embeds
                 encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds
 
+        # add the first frame conditioning
+        if adapter.frame_embedder is not None:
+            with torch.no_grad():
+                # add the first frame conditioning
+                conditioning_frame = adapter.adapter_ref().cached_control_image_0_1
+                if conditioning_frame is None:
+                    raise ValueError("No conditioning frame found")
+
+                # make it -1 to 1
+                conditioning_frame = (conditioning_frame * 2) - 1
+                conditioning_frame = conditioning_frame.to(
+                    hidden_states.device, dtype=hidden_states.dtype
+                )
+
+                # if doing a full denoise, the latent input may be full channels here, only get the first 16
+                if hidden_states.shape[1] > 16:
+                    hidden_states = hidden_states[:, :16, :, :, :]
+
+
+                hidden_states = add_first_frame_conditioning(
+                    latent_model_input=hidden_states,
+                    first_frame=conditioning_frame,
+                    vae=adapter.adapter_ref().sd_ref().vae,
+                )
     else:
         # not active deactivate the condition embedder
         self.condition_embedder.image_embedder = None
@@ -450,9 +474,7 @@ class I2VAdapter(torch.nn.Module):
         if self.config.i2v_do_start_frame:
             self.frame_embedder = FrameEmbedder.from_model(
                 sd.unet,
-                self,
-                num_control_images=config.num_control_images,
-                has_inpainting_input=config.has_inpainting_input
+                self
             )
             self.frame_embedder.to(self.device_torch)
 
diff --git a/toolkit/models/wan21/wan21_i2v.py b/toolkit/models/wan21/wan21_i2v.py
index 3a7a0181..2bd23b2a 100644
--- a/toolkit/models/wan21/wan21_i2v.py
+++ b/toolkit/models/wan21/wan21_i2v.py
@@ -41,6 +41,8 @@ from .wan21 import \
     scheduler_config, \
     Wan21
 
+from .wan_utils import add_first_frame_conditioning
+
 
 class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
     def __init__(
@@ -269,6 +271,9 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
         image = {k: v.to(self.image_encoder.device, dtype=self.image_encoder.dtype) for k, v in image.items()}
         image_embeds = self.image_encoder(**image, output_hidden_states=True)
         return image_embeds.hidden_states[-2]
+
+
+
 
 
 class Wan21I2V(Wan21):
@@ -488,46 +493,12 @@ class Wan21I2V(Wan21):
         image_embeds = image_embeds_full.hidden_states[-2]
         image_embeds = image_embeds.to(self.device_torch, dtype=self.torch_dtype)
 
-        # condition latent
-        # first_frames shape is (bs, channels, height, width)
-        # wan needs latends in (bs, channels, num_frames, height, width)
-        first_frames = first_frames.unsqueeze(2)
-        # video condition is first frame is the frame, the rest are zeros
-        num_frames = frames.shape[1]
-
-        zero_frame = torch.zeros_like(first_frames)
-        video_condition = torch.cat([
-            first_frames,
-            *[zero_frame for _ in range(num_frames - 1)]
-        ], dim=2)
-
-        # our vae encoder expects (bs, num_frames, channels, height, width)
-        # permute to (bs, channels, num_frames, height, width)
-        video_condition = video_condition.permute(0, 2, 1, 3, 4)
-
-        latent_condition = self.encode_images(
-            video_condition,
-            device=self.device_torch,
-            dtype=self.torch_dtype,
+        # Add conditioning using the standalone function
+        conditioned_latent = add_first_frame_conditioning(
+            latent_model_input=latent_model_input,
+            first_frame=first_frames,
+            vae=self.vae
         )
-        latent_condition = latent_condition.to(self.device_torch, dtype=self.torch_dtype)
-
-        batch_size = frames.shape[0]
-        latent_height = latent_condition.shape[3]
-        latent_width = latent_condition.shape[4]
-
-        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
-        mask_lat_size[:, :, list(range(1, num_frames))] = 0
-        first_frame_mask = mask_lat_size[:, :, 0:1]
-        first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.pipeline.vae_scale_factor_temporal)
-        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
-        mask_lat_size = mask_lat_size.view(batch_size, -1, self.pipeline.vae_scale_factor_temporal, latent_height, latent_width)
-        mask_lat_size = mask_lat_size.transpose(1, 2)
-        mask_lat_size = mask_lat_size.to(self.device_torch, dtype=self.torch_dtype)
-
-        # return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
-        first_frame_condition = torch.concat([mask_lat_size, latent_condition], dim=1)
-        conditioned_latent = torch.cat([latent_model_input, first_frame_condition], dim=1)
 
         noise_pred = self.model(
             hidden_states=conditioned_latent,
@@ -537,4 +508,4 @@ class Wan21I2V(Wan21):
             return_dict=False,
             **kwargs
         )[0]
-        return noise_pred
+        return noise_pred
\ No newline at end of file
diff --git a/toolkit/models/wan21/wan_utils.py b/toolkit/models/wan21/wan_utils.py
new file mode 100644
index 00000000..1f837ce6
--- /dev/null
+++ b/toolkit/models/wan21/wan_utils.py
@@ -0,0 +1,102 @@
+import torch
+import torch.nn.functional as F
+
+
+def add_first_frame_conditioning(
+    latent_model_input,
+    first_frame,
+    vae
+):
+    """
+    Adds first frame conditioning to a video diffusion model input.
+
+    Args:
+        latent_model_input: Original latent input (bs, channels, num_frames, height, width)
+        first_frame: Tensor of first frame to condition on (bs, channels, height, width)
+        vae: VAE model for encoding the conditioning
+
+    Returns:
+        conditioned_latent: The complete conditioned latent input (bs, 36, num_frames, height, width)
+    """
+    device = latent_model_input.device
+    dtype = latent_model_input.dtype
+    vae_scale_factor_temporal = 2 ** sum(vae.temperal_downsample)
+
+    # Get number of frames from latent model input
+    _, _, num_latent_frames, _, _ = latent_model_input.shape
+
+    # Calculate original number of frames
+    # For n original frames, there are (n - 1) // 4 + 1 latent frames
+    # So to get n: n = (num_latent_frames - 1) * 4 + 1
+    num_frames = (num_latent_frames - 1) * 4 + 1
+
+    if len(first_frame.shape) == 3:
+        # we have a single image
+        first_frame = first_frame.unsqueeze(0)
+
+    # if it doesn't match the batch size, we need to expand it
+    if first_frame.shape[0] != latent_model_input.shape[0]:
+        first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1)
+
+    # resize first frame to match the latent model input
+    vae_scale_factor = 8
+    first_frame = F.interpolate(
+        first_frame,
+        size=(latent_model_input.shape[3] * vae_scale_factor, latent_model_input.shape[4] * vae_scale_factor),
+        mode='bilinear',
+        align_corners=False
+    )
+
+    # Add temporal dimension to first frame
+    first_frame = first_frame.unsqueeze(2)
+
+    # Create video condition with first frame and zeros for remaining frames
+    zero_frame = torch.zeros_like(first_frame)
+    video_condition = torch.cat([
+        first_frame,
+        *[zero_frame for _ in range(num_frames - 1)]
+    ], dim=2)
+
+    # Prepare for VAE encoding (bs, channels, num_frames, height, width)
+    # video_condition = video_condition.permute(0, 2, 1, 3, 4)
+
+    # Encode with VAE
+    latent_condition = vae.encode(
+        video_condition.to(device, dtype)
+    ).latent_dist.sample()
+    latent_condition = latent_condition.to(device, dtype)
+
+    # Create mask: 1 for conditioning frames, 0 for frames to generate
+    batch_size = first_frame.shape[0]
+    latent_height = latent_condition.shape[3]
+    latent_width = latent_condition.shape[4]
+
+    # Initialize mask for all frames
+    mask_lat_size = torch.ones(
+        batch_size, 1, num_frames, latent_height, latent_width)
+
+    # Set all non-first frames to 0
+    mask_lat_size[:, :, list(range(1, num_frames))] = 0
+
+    # Special handling for first frame
+    first_frame_mask = mask_lat_size[:, :, 0:1]
+    first_frame_mask = torch.repeat_interleave(
+        first_frame_mask, dim=2, repeats=vae_scale_factor_temporal)
+
+    # Combine first frame mask with rest
+    mask_lat_size = torch.concat(
+        [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+
+    # Reshape and transpose for model input
+    mask_lat_size = mask_lat_size.view(
+        batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
+    mask_lat_size = mask_lat_size.transpose(1, 2)
+    mask_lat_size = mask_lat_size.to(device, dtype)
+
+    # Combine conditioning with latent input
+    first_frame_condition = torch.concat(
+        [mask_lat_size, latent_condition], dim=1)
+    conditioned_latent = torch.cat(
+        [latent_model_input, first_frame_condition], dim=1)
+
+    return conditioned_latent
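A minimal, self-contained sketch of the caching step added to CustomAdapter.trigger_pre_te: batch.tensor arrives in [-1, 1] and videos arrive as (bs, num_frames, channels, height, width), per the comments in the diff. Tensor sizes below are toy values, not ones used by the trainer.

import torch

# toy video batch in [-1, 1], shaped (bs, num_frames, channels, height, width)
batch_tensor = torch.rand(2, 8, 3, 64, 64) * 2 - 1

to_cache = batch_tensor / 2 + 0.5      # rescale to 0..1
if len(to_cache.shape) == 5:           # video: keep only the first frame
    to_cache = to_cache[:, 0:1].squeeze(1)
print(to_cache.shape)                  # torch.Size([2, 3, 64, 64])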
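For the FrameEmbedder change in toolkit/models/i2v_adapter.py, a minimal sketch of the 36-channel split the new forward performs: the first 16 channels go through the frozen original patch embedding, the extra 20 (4 mask + 16 encoded first frame) go through the new Conv3d, and the two outputs are summed. inner_dim and patch_size here are placeholder values for illustration, not Wan's actual config.

import torch

inner_dim = 64                 # placeholder, not Wan's real hidden size
patch_size = (1, 2, 2)         # placeholder patch size
orig_patch_embedding = torch.nn.Conv3d(16, inner_dim, kernel_size=patch_size, stride=patch_size)
new_patch_embedding = torch.nn.Conv3d(20, inner_dim, kernel_size=patch_size, stride=patch_size)

# conditioned latent: (bs, 16 + 4 + 16, num_latent_frames, latent_h, latent_w)
x = torch.randn(1, 36, 5, 16, 16)

orig_out = orig_patch_embedding(x[:, :16])    # original 16 latent channels
extra_out = new_patch_embedding(x[:, 16:])    # mask + encoded-first-frame channels
out = orig_out + extra_out                    # same shape as the unmodified embedding
print(out.shape)                              # torch.Size([1, 64, 5, 8, 8])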
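And a shape walk-through of the temporal mask packing inside wan_utils.add_first_frame_conditioning, showing why the mask contributes 4 channels per latent frame. Sizes are toy values; a temporal VAE factor of 4 is assumed, matching 2 ** sum(vae.temperal_downsample) for Wan 2.1.

import torch

batch_size, num_frames, latent_h, latent_w = 1, 9, 8, 8
vae_scale_factor_temporal = 4
num_latent_frames = (num_frames - 1) // 4 + 1           # 3

mask = torch.ones(batch_size, 1, num_frames, latent_h, latent_w)
mask[:, :, 1:] = 0                                      # only the first frame is given
first = mask[:, :, 0:1].repeat_interleave(vae_scale_factor_temporal, dim=2)
mask = torch.cat([first, mask[:, :, 1:]], dim=2)        # (1, 1, 12, 8, 8)
mask = mask.view(batch_size, -1, vae_scale_factor_temporal, latent_h, latent_w).transpose(1, 2)
print(mask.shape)                                       # torch.Size([1, 4, 3, 8, 8])

# Concatenated with the 16-channel encoded first-frame latent and the 16-channel
# model input, this gives the 36-channel hidden_states the I2V transformer expects.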