Fix issue with device placement on te

This commit is contained in:
Jaret Burkett
2025-03-13 20:48:12 -06:00
parent 3b45892b4f
commit 391329dbdc

View File

@@ -83,6 +83,27 @@ scheduler_config = {
class AggressiveWanUnloadPipeline(WanPipeline):
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
device: torch.device = torch.device("cuda"),
):
super().__init__(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=vae,
scheduler=scheduler,
)
self._exec_device = device
@property
def _execution_device(self):
return self._exec_device
def __call__(
self: WanPipeline,
prompt: Union[str, List[str]] = None,
@@ -459,6 +480,7 @@ class Wan21(BaseModel):
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=scheduler,
device=self.device_torch
)
else:
pipeline = WanPipeline(
@@ -540,6 +562,8 @@ class Wan21(BaseModel):
return noise_pred
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
if self.pipeline.text_encoder.device != self.device_torch:
self.pipeline.text_encoder.to(self.device_torch)
prompt_embeds, _ = self.pipeline.encode_prompt(
prompt,
do_classifier_free_guidance=False,