mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fixed cuda error when not all tensors have been moved to the correct device.
This commit is contained in:
@@ -117,6 +117,7 @@ class AggressiveWanUnloadPipeline(WanPipeline):
|
||||
text_encoder_device = self.text_encoder.device
|
||||
print("Unloading vae")
|
||||
self.vae.to("cpu")
|
||||
self.text_encoder.to(self._execution_device)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
@@ -159,6 +160,8 @@ class AggressiveWanUnloadPipeline(WanPipeline):
|
||||
# unload text encoder
|
||||
print("Unloading text encoder")
|
||||
self.text_encoder.to("cpu")
|
||||
|
||||
self.transformer.to(self._execution_device)
|
||||
|
||||
transformer_dtype = self.transformer.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
|
||||
Reference in New Issue
Block a user