Fix issue with wan i2v scaling. Adjust aggressive loader to be compatible with updated diffusers.

This commit is contained in:
Jaret Burkett
2025-07-12 16:56:27 -06:00
parent 2e84b3d5b1
commit 755f0e207c
2 changed files with 27 additions and 8 deletions

View File

@@ -99,7 +99,8 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

         # unload vae and transformer
-        device = self.transformer.device
+        # device = self.transformer.device
+        device = self._exec_device
         self.text_encoder.to(device)
@@ -116,7 +117,8 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
             width,
             prompt_embeds,
             negative_prompt_embeds,
-            callback_on_step_end_tensor_inputs,
+            image_embeds=None,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )

         self._guidance_scale = guidance_scale
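
The switch to keyword arguments is what keeps this overridden call working after diffusers added an image_embeds parameter to check_inputs. A minimal, hypothetical sketch of the failure mode; the function names below are stand-ins, not the library's code:

def check_inputs_old(prompt, height, width, prompt_embeds=None,
                     negative_prompt_embeds=None,
                     callback_on_step_end_tensor_inputs=None):
    pass

def check_inputs_new(prompt, height, width, prompt_embeds=None,
                     negative_prompt_embeds=None, image_embeds=None,
                     callback_on_step_end_tensor_inputs=None):
    # With a new parameter inserted before the last one, a positional call
    # written against the old signature would put the tensor-input list into
    # image_embeds instead. Passing it by keyword avoids that.
    pass

check_inputs_new("a cat", 480, 832, None, None,
                 image_embeds=None,
                 callback_on_step_end_tensor_inputs=["latents"])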
@@ -145,7 +147,6 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
         )

         # unload text encoder
-        print("Unloading text encoder")
         self.text_encoder.to("cpu")
         self.transformer.to(device)
         flush()
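
The hunks above implement an aggressive offload schedule: the text encoder is the only large module on the accelerator while prompts are encoded, then it is swapped out for the transformer. Reading the target device from self._exec_device matters because self.transformer.device may still report cpu at that point. A rough, self-contained sketch of the pattern, with assumed names (flush, encode_then_denoise) rather than the repository's exact code:

import gc
import torch

def flush():
    # Release cached allocator blocks after moving a module off the GPU.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def encode_then_denoise(text_encoder, transformer, exec_device="cuda"):
    # Phase 1: prompt encoding. The transformer may still be on the CPU here,
    # which is why the pipeline reads its execution device rather than
    # transformer.device.
    text_encoder.to(exec_device)
    # ... run the text encoder ...
    text_encoder.to("cpu")
    flush()
    # Phase 2: denoising. Only the transformer occupies the GPU.
    transformer.to(exec_device)
    # ... run the denoising loop ...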
@@ -456,13 +457,20 @@ class Wan21I2V(Wan21):
         # Resize to 224x224 (using bilinear interpolation, which is resample=3 in PIL)
         if tensor.shape[2] != 224 or tensor.shape[3] != 224:
             tensor = F.interpolate(tensor, size=(224, 224), mode='bilinear', align_corners=False)

-        # Normalize with mean and std
-        mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1).to(tensor.device)
-        std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1).to(tensor.device)
-        tensor = (tensor - mean) / std
+        tensors_0_1 = tensor.clamp(0, 1)  # Ensure values are in [0, 1] range
+        mean = torch.tensor(self.image_processor.image_mean).to(
+            tensors_0_1.device, dtype=tensors_0_1.dtype
+        ).view([1, 3, 1, 1]).detach()
+        std = torch.tensor(self.image_processor.image_std).to(
+            tensors_0_1.device, dtype=tensors_0_1.dtype
+        ).view([1, 3, 1, 1]).detach()
+        # tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
+        clip_image = (tensors_0_1 - mean) / std

-        return tensor
+        return clip_image.detach()

     def get_noise_prediction(
         self,
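
The rewritten preprocessing clamps the image to [0, 1] and pulls the normalization statistics from the pipeline's CLIP image processor instead of hard-coding them. A small standalone sketch of the same computation, with preprocess_for_clip as an illustrative name:

import torch
import torch.nn.functional as F

def preprocess_for_clip(tensor, image_mean, image_std):
    # tensor: (B, 3, H, W), roughly in [0, 1]
    if tensor.shape[2] != 224 or tensor.shape[3] != 224:
        tensor = F.interpolate(tensor, size=(224, 224), mode='bilinear', align_corners=False)
    tensors_0_1 = tensor.clamp(0, 1)
    mean = torch.tensor(image_mean, device=tensor.device, dtype=tensor.dtype).view(1, 3, 1, 1)
    std = torch.tensor(image_std, device=tensor.device, dtype=tensor.dtype).view(1, 3, 1, 1)
    return ((tensors_0_1 - mean) / std).detach()

# The standard CLIP statistics (the values the old code hard-coded):
x = torch.rand(1, 3, 480, 832)
y = preprocess_for_clip(
    x,
    [0.48145466, 0.4578275, 0.40821073],
    [0.26862954, 0.26130258, 0.27577711],
)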

View File

@@ -65,6 +65,17 @@ def add_first_frame_conditioning(
        video_condition.to(device, dtype)
    ).latent_dist.sample()
    latent_condition = latent_condition.to(device, dtype)
+
+    latents_mean = (
+        torch.tensor(vae.config.latents_mean)
+        .view(1, vae.config.z_dim, 1, 1, 1)
+        .to(device, dtype)
+    )
+    latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
+        device, dtype
+    )
+    latent_condition = (latent_condition - latents_mean) * latents_std

     # Create mask: 1 for conditioning frames, 0 for frames to generate
     batch_size = first_frame.shape[0]
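
The added lines normalize the conditioning latents with the per-channel statistics stored in the Wan VAE config, so the first-frame condition lands on the same scale as the latents the transformer expects. A standalone sketch of that scaling, with toy values standing in for vae.config.latents_mean / latents_std and z_dim:

import torch

def normalize_wan_latents(latent, latents_mean, latents_std, z_dim):
    # latent: (B, z_dim, T, H, W) sampled from the VAE latent distribution
    mean = torch.tensor(latents_mean).view(1, z_dim, 1, 1, 1).to(latent.device, latent.dtype)
    inv_std = 1.0 / torch.tensor(latents_std).view(1, z_dim, 1, 1, 1).to(latent.device, latent.dtype)
    return (latent - mean) * inv_std

# Toy example (the real Wan 2.1 VAE uses 16 latent channels):
lat = torch.randn(1, 4, 3, 8, 8)
out = normalize_wan_latents(lat, [0.0] * 4, [1.0] * 4, 4)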