diff --git a/toolkit/models/wan21.py b/toolkit/models/wan21.py index dee07040..10e1b9c8 100644 --- a/toolkit/models/wan21.py +++ b/toolkit/models/wan21.py @@ -117,6 +117,7 @@ class AggressiveWanUnloadPipeline(WanPipeline): text_encoder_device = self.text_encoder.device print("Unloading vae") self.vae.to("cpu") + self.text_encoder.to(self._execution_device) # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -159,6 +160,8 @@ class AggressiveWanUnloadPipeline(WanPipeline): # unload text encoder print("Unloading text encoder") self.text_encoder.to("cpu") + + self.transformer.to(self._execution_device) transformer_dtype = self.transformer.dtype prompt_embeds = prompt_embeds.to(transformer_dtype)