Fixed cuda error when not all tensors have been moved to the correct device.

This commit is contained in:
Jaret Burkett
2025-03-07 22:04:35 -07:00
parent 25341c4613
commit 4d88f8f218

View File

@@ -117,6 +117,7 @@ class AggressiveWanUnloadPipeline(WanPipeline):
text_encoder_device = self.text_encoder.device
print("Unloading vae")
self.vae.to("cpu")
self.text_encoder.to(self._execution_device)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -160,6 +161,8 @@ class AggressiveWanUnloadPipeline(WanPipeline):
print("Unloading text encoder")
self.text_encoder.to("cpu")
self.transformer.to(self._execution_device)
transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None: