mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fix issue with wan i2v scaling. Adjust aggressive loader to be compatable with updated diffusers.
This commit is contained in:
@@ -99,7 +99,8 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
|
|||||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||||
|
|
||||||
# unload vae and transformer
|
# unload vae and transformer
|
||||||
device = self.transformer.device
|
# device = self.transformer.device
|
||||||
|
device = self._exec_device
|
||||||
|
|
||||||
self.text_encoder.to(device)
|
self.text_encoder.to(device)
|
||||||
|
|
||||||
@@ -116,7 +117,8 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
|
|||||||
width,
|
width,
|
||||||
prompt_embeds,
|
prompt_embeds,
|
||||||
negative_prompt_embeds,
|
negative_prompt_embeds,
|
||||||
callback_on_step_end_tensor_inputs,
|
image_embeds=None,
|
||||||
|
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._guidance_scale = guidance_scale
|
self._guidance_scale = guidance_scale
|
||||||
@@ -145,7 +147,6 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# unload text encoder
|
# unload text encoder
|
||||||
print("Unloading text encoder")
|
|
||||||
self.text_encoder.to("cpu")
|
self.text_encoder.to("cpu")
|
||||||
self.transformer.to(device)
|
self.transformer.to(device)
|
||||||
flush()
|
flush()
|
||||||
@@ -456,13 +457,20 @@ class Wan21I2V(Wan21):
|
|||||||
# Resize to 224x224 (using bilinear interpolation, which is resample=3 in PIL)
|
# Resize to 224x224 (using bilinear interpolation, which is resample=3 in PIL)
|
||||||
if tensor.shape[2] != 224 or tensor.shape[3] != 224:
|
if tensor.shape[2] != 224 or tensor.shape[3] != 224:
|
||||||
tensor = F.interpolate(tensor, size=(224, 224), mode='bilinear', align_corners=False)
|
tensor = F.interpolate(tensor, size=(224, 224), mode='bilinear', align_corners=False)
|
||||||
|
|
||||||
|
tensors_0_1 = tensor.clamp(0, 1) # Ensure values are in [0, 1] range
|
||||||
|
|
||||||
# Normalize with mean and std
|
mean = torch.tensor(self.image_processor.image_mean).to(
|
||||||
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1).to(tensor.device)
|
tensors_0_1.device, dtype=tensors_0_1.dtype
|
||||||
std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1).to(tensor.device)
|
).view([1, 3, 1, 1]).detach()
|
||||||
tensor = (tensor - mean) / std
|
std = torch.tensor(self.image_processor.image_std).to(
|
||||||
|
tensors_0_1.device, dtype=tensors_0_1.dtype
|
||||||
|
).view([1, 3, 1, 1]).detach()
|
||||||
|
|
||||||
return tensor
|
# tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
|
||||||
|
clip_image = (tensors_0_1 - mean) / std
|
||||||
|
|
||||||
|
return clip_image.detach()
|
||||||
|
|
||||||
def get_noise_prediction(
|
def get_noise_prediction(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -65,6 +65,17 @@ def add_first_frame_conditioning(
|
|||||||
video_condition.to(device, dtype)
|
video_condition.to(device, dtype)
|
||||||
).latent_dist.sample()
|
).latent_dist.sample()
|
||||||
latent_condition = latent_condition.to(device, dtype)
|
latent_condition = latent_condition.to(device, dtype)
|
||||||
|
|
||||||
|
latents_mean = (
|
||||||
|
torch.tensor(vae.config.latents_mean)
|
||||||
|
.view(1, vae.config.z_dim, 1, 1, 1)
|
||||||
|
.to(device, dtype)
|
||||||
|
)
|
||||||
|
latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
|
||||||
|
device, dtype
|
||||||
|
)
|
||||||
|
latent_condition = (latent_condition - latents_mean) * latents_std
|
||||||
|
|
||||||
|
|
||||||
# Create mask: 1 for conditioning frames, 0 for frames to generate
|
# Create mask: 1 for conditioning frames, 0 for frames to generate
|
||||||
batch_size = first_frame.shape[0]
|
batch_size = first_frame.shape[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user