Fixed cuda error when not all tensors have been moved to the correct device.

This commit is contained in:
Jaret Burkett
2025-03-07 22:04:35 -07:00
parent 25341c4613
commit 4d88f8f218

View File

@@ -117,6 +117,7 @@ class AggressiveWanUnloadPipeline(WanPipeline):
text_encoder_device = self.text_encoder.device
print("Unloading vae")
self.vae.to("cpu")
self.text_encoder.to(self._execution_device)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -159,6 +160,8 @@ class AggressiveWanUnloadPipeline(WanPipeline):
# unload text encoder
print("Unloading text encoder")
self.text_encoder.to("cpu")
self.transformer.to(self._execution_device)
transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)