From 755f0e207c327cc7fabf16569702e29ddb1103fa Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 12 Jul 2025 16:56:27 -0600
Subject: [PATCH] Fix issue with wan i2v scaling. Adjust aggressive loader to
 be compatible with updated diffusers.

---
 toolkit/models/wan21/wan21_i2v.py | 24 ++++++++++++++++--------
 toolkit/models/wan21/wan_utils.py | 11 +++++++++++
 2 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/toolkit/models/wan21/wan21_i2v.py b/toolkit/models/wan21/wan21_i2v.py
index b2b3afd6..8b9f2918 100644
--- a/toolkit/models/wan21/wan21_i2v.py
+++ b/toolkit/models/wan21/wan21_i2v.py
@@ -99,7 +99,8 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
         # unload vae and transformer
-        device = self.transformer.device
+        # device = self.transformer.device
+        device = self._exec_device
 
         self.text_encoder.to(device)
 
@@ -116,7 +117,8 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
             width,
             prompt_embeds,
             negative_prompt_embeds,
-            callback_on_step_end_tensor_inputs,
+            image_embeds=None,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )
 
         self._guidance_scale = guidance_scale
@@ -145,7 +147,6 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
             )
 
         # unload text encoder
-        print("Unloading text encoder")
         self.text_encoder.to("cpu")
         self.transformer.to(device)
         flush()
@@ -456,13 +457,20 @@ class Wan21I2V(Wan21):
         # Resize to 224x224 (using bilinear interpolation, which is resample=3 in PIL)
         if tensor.shape[2] != 224 or tensor.shape[3] != 224:
             tensor = F.interpolate(tensor, size=(224, 224), mode='bilinear', align_corners=False)
+
+        tensors_0_1 = tensor.clamp(0, 1)  # Ensure values are in [0, 1] range
 
-        # Normalize with mean and std
-        mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1).to(tensor.device)
-        std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1).to(tensor.device)
-        tensor = (tensor - mean) / std
+        mean = torch.tensor(self.image_processor.image_mean).to(
+            tensors_0_1.device, dtype=tensors_0_1.dtype
+        ).view([1, 3, 1, 1]).detach()
+        std = torch.tensor(self.image_processor.image_std).to(
+            tensors_0_1.device, dtype=tensors_0_1.dtype
+        ).view([1, 3, 1, 1]).detach()
 
-        return tensor
+        # tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
+        clip_image = (tensors_0_1 - mean) / std
+
+        return clip_image.detach()
 
     def get_noise_prediction(
         self,
diff --git a/toolkit/models/wan21/wan_utils.py b/toolkit/models/wan21/wan_utils.py
index 1f837ce6..422cf3f7 100644
--- a/toolkit/models/wan21/wan_utils.py
+++ b/toolkit/models/wan21/wan_utils.py
@@ -65,6 +65,17 @@ def add_first_frame_conditioning(
         video_condition.to(device, dtype)
     ).latent_dist.sample()
     latent_condition = latent_condition.to(device, dtype)
+
+    latents_mean = (
+        torch.tensor(vae.config.latents_mean)
+        .view(1, vae.config.z_dim, 1, 1, 1)
+        .to(device, dtype)
+    )
+    latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
+        device, dtype
+    )
+    latent_condition = (latent_condition - latents_mean) * latents_std
+
     # Create mask: 1 for conditioning frames, 0 for frames to generate
     batch_size = first_frame.shape[0]
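
The wan_utils.py hunk normalizes the encoded first-frame latents with the per-channel
statistics stored in the Wan VAE config. A minimal standalone sketch of that scaling,
assuming a diffusers-style Wan VAE whose config exposes latents_mean, latents_std,
and z_dim (the helper name scale_wan_latent_condition is illustrative, not a function
from the repo):

    import torch

    def scale_wan_latent_condition(latent_condition: torch.Tensor, vae) -> torch.Tensor:
        # Shift and scale raw VAE posterior samples into the normalized latent
        # space the transformer expects, using the VAE config statistics.
        device, dtype = latent_condition.device, latent_condition.dtype
        latents_mean = (
            torch.tensor(vae.config.latents_mean)
            .view(1, vae.config.z_dim, 1, 1, 1)
            .to(device, dtype)
        )
        inv_latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(
            1, vae.config.z_dim, 1, 1, 1
        ).to(device, dtype)
        return (latent_condition - latents_mean) * inv_latents_std

The wan21_i2v.py hunk replaces hard-coded CLIP normalization constants with the
mean/std reported by the pipeline's image processor. A minimal sketch of the same
preprocessing, assuming an image_processor object with image_mean and image_std
attributes (e.g. a CLIPImageProcessor); preprocess_clip_image is an illustrative name:

    import torch
    import torch.nn.functional as F

    def preprocess_clip_image(tensor: torch.Tensor, image_processor) -> torch.Tensor:
        # Resize to CLIP's 224x224 input, clamp to [0, 1], then normalize with
        # the processor-provided statistics instead of hard-coded constants.
        if tensor.shape[2] != 224 or tensor.shape[3] != 224:
            tensor = F.interpolate(tensor, size=(224, 224), mode='bilinear', align_corners=False)
        tensors_0_1 = tensor.clamp(0, 1)
        mean = torch.tensor(image_processor.image_mean).view(1, 3, 1, 1).to(
            tensors_0_1.device, dtype=tensors_0_1.dtype
        )
        std = torch.tensor(image_processor.image_std).view(1, 3, 1, 1).to(
            tensors_0_1.device, dtype=tensors_0_1.dtype
        )
        return ((tensors_0_1 - mean) / std).detach()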