mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 11:41:35 +00:00
Fixed cuda error when not all tensors have been moved to the correct device.
This commit is contained in:
@@ -117,6 +117,7 @@ class AggressiveWanUnloadPipeline(WanPipeline):
|
||||
text_encoder_device = self.text_encoder.device
|
||||
print("Unloading vae")
|
||||
self.vae.to("cpu")
|
||||
self.text_encoder.to(self._execution_device)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
@@ -160,6 +161,8 @@ class AggressiveWanUnloadPipeline(WanPipeline):
|
||||
print("Unloading text encoder")
|
||||
self.text_encoder.to("cpu")
|
||||
|
||||
self.transformer.to(self._execution_device)
|
||||
|
||||
transformer_dtype = self.transformer.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
|
||||
Reference in New Issue
Block a user