mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fix issue with wan i2v scaling. Adjust aggressive loader to be compatable with updated diffusers.
This commit is contained in:
@@ -99,7 +99,8 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# unload vae and transformer
|
||||
device = self.transformer.device
|
||||
# device = self.transformer.device
|
||||
device = self._exec_device
|
||||
|
||||
self.text_encoder.to(device)
|
||||
|
||||
@@ -116,7 +117,8 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
image_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
@@ -145,7 +147,6 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
|
||||
)
|
||||
|
||||
# unload text encoder
|
||||
print("Unloading text encoder")
|
||||
self.text_encoder.to("cpu")
|
||||
self.transformer.to(device)
|
||||
flush()
|
||||
@@ -456,13 +457,20 @@ class Wan21I2V(Wan21):
|
||||
# Resize to 224x224 (using bilinear interpolation, which is resample=3 in PIL)
|
||||
if tensor.shape[2] != 224 or tensor.shape[3] != 224:
|
||||
tensor = F.interpolate(tensor, size=(224, 224), mode='bilinear', align_corners=False)
|
||||
|
||||
tensors_0_1 = tensor.clamp(0, 1) # Ensure values are in [0, 1] range
|
||||
|
||||
# Normalize with mean and std
|
||||
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1).to(tensor.device)
|
||||
std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1).to(tensor.device)
|
||||
tensor = (tensor - mean) / std
|
||||
mean = torch.tensor(self.image_processor.image_mean).to(
|
||||
tensors_0_1.device, dtype=tensors_0_1.dtype
|
||||
).view([1, 3, 1, 1]).detach()
|
||||
std = torch.tensor(self.image_processor.image_std).to(
|
||||
tensors_0_1.device, dtype=tensors_0_1.dtype
|
||||
).view([1, 3, 1, 1]).detach()
|
||||
|
||||
return tensor
|
||||
# tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
|
||||
clip_image = (tensors_0_1 - mean) / std
|
||||
|
||||
return clip_image.detach()
|
||||
|
||||
def get_noise_prediction(
|
||||
self,
|
||||
|
||||
@@ -65,6 +65,17 @@ def add_first_frame_conditioning(
|
||||
video_condition.to(device, dtype)
|
||||
).latent_dist.sample()
|
||||
latent_condition = latent_condition.to(device, dtype)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(vae.config.latents_mean)
|
||||
.view(1, vae.config.z_dim, 1, 1, 1)
|
||||
.to(device, dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
|
||||
device, dtype
|
||||
)
|
||||
latent_condition = (latent_condition - latents_mean) * latents_std
|
||||
|
||||
|
||||
# Create mask: 1 for conditioning frames, 0 for frames to generate
|
||||
batch_size = first_frame.shape[0]
|
||||
|
||||
Reference in New Issue
Block a user