From 391329dbdc3e6269c5945863407d0e561f8a3e88 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 13 Mar 2025 20:48:12 -0600
Subject: [PATCH] Fix issue with device placement on te

---
 toolkit/models/wan21/wan21.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py
index 9ef6891f..886cbcc6 100644
--- a/toolkit/models/wan21/wan21.py
+++ b/toolkit/models/wan21/wan21.py
@@ -83,6 +83,27 @@ scheduler_config = {
 
 
 class AggressiveWanUnloadPipeline(WanPipeline):
+    def __init__(
+        self,
+        tokenizer: AutoTokenizer,
+        text_encoder: UMT5EncoderModel,
+        transformer: WanTransformer3DModel,
+        vae: AutoencoderKLWan,
+        scheduler: FlowMatchEulerDiscreteScheduler,
+        device: torch.device = torch.device("cuda"),
+    ):
+        super().__init__(
+            tokenizer=tokenizer,
+            text_encoder=text_encoder,
+            transformer=transformer,
+            vae=vae,
+            scheduler=scheduler,
+        )
+        self._exec_device = device
+    @property
+    def _execution_device(self):
+        return self._exec_device
+
     def __call__(
         self: WanPipeline,
         prompt: Union[str, List[str]] = None,
@@ -459,6 +480,7 @@ class Wan21(BaseModel):
                 text_encoder=self.text_encoder,
                 tokenizer=self.tokenizer,
                 scheduler=scheduler,
+                device=self.device_torch
             )
         else:
             pipeline = WanPipeline(
@@ -540,6 +562,8 @@ class Wan21(BaseModel):
         return noise_pred
 
     def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
+        if self.pipeline.text_encoder.device != self.device_torch:
+            self.pipeline.text_encoder.to(self.device_torch)
         prompt_embeds, _ = self.pipeline.encode_prompt(
             prompt,
             do_classifier_free_guidance=False,