Fix issue with device placement on te
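Summary of the diff below: it adds an AggressiveWanUnloadPipeline subclass of WanPipeline that takes an explicit device argument and pins _execution_device to it rather than inferring the device from (possibly unloaded) submodules; Wan21 passes self.device_torch through when constructing that pipeline, and get_prompt_embeds moves the text encoder back onto the target device before encoding ("te" in the title presumably refers to the text encoder).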
@@ -83,6 +83,27 @@ scheduler_config = {
 
 
+class AggressiveWanUnloadPipeline(WanPipeline):
+    def __init__(
+        self,
+        tokenizer: AutoTokenizer,
+        text_encoder: UMT5EncoderModel,
+        transformer: WanTransformer3DModel,
+        vae: AutoencoderKLWan,
+        scheduler: FlowMatchEulerDiscreteScheduler,
+        device: torch.device = torch.device("cuda"),
+    ):
+        super().__init__(
+            tokenizer=tokenizer,
+            text_encoder=text_encoder,
+            transformer=transformer,
+            vae=vae,
+            scheduler=scheduler,
+        )
+        self._exec_device = device
+    @property
+    def _execution_device(self):
+        return self._exec_device
 
     def __call__(
         self: WanPipeline,
         prompt: Union[str, List[str]] = None,
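For context on the hunk above: diffusers pipelines normally derive their execution device from the placement of their registered submodules, so a pipeline that aggressively unloads modules to CPU ends up reporting CPU and allocating new tensors there. Overriding _execution_device pins the device explicitly. Below is a minimal, self-contained sketch of that pattern; ToyPipeline and PinnedDevicePipeline are illustrative stand-ins for the real WanPipeline, not code from the repo:

import torch

class ToyPipeline:
    # Stand-in for a diffusers-style pipeline: by default the execution
    # device is inferred from wherever a submodule currently lives,
    # which goes wrong once that submodule is offloaded to CPU.
    def __init__(self, text_encoder: torch.nn.Module):
        self.text_encoder = text_encoder

    @property
    def _execution_device(self) -> torch.device:
        return next(self.text_encoder.parameters()).device

class PinnedDevicePipeline(ToyPipeline):
    # The pattern from the diff: remember the intended device at
    # construction time and always report it, no matter where the
    # (possibly unloaded) submodules happen to be.
    def __init__(self, text_encoder: torch.nn.Module,
                 device: torch.device = torch.device("cuda")):
        super().__init__(text_encoder=text_encoder)
        self._exec_device = device

    @property
    def _execution_device(self) -> torch.device:
        return self._exec_device

target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
te = torch.nn.Linear(8, 8)                      # toy "text encoder"
pipe = PinnedDevicePipeline(te, device=target)  # mirrors device=self.device_torch
te.to("cpu")                                    # simulate aggressive unloading
assert pipe._execution_device == target         # still reports the pinned device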
@@ -459,6 +480,7 @@ class Wan21(BaseModel):
                 text_encoder=self.text_encoder,
                 tokenizer=self.tokenizer,
                 scheduler=scheduler,
+                device=self.device_torch
             )
         else:
             pipeline = WanPipeline(
@@ -540,6 +562,8 @@ class Wan21(BaseModel):
         return noise_pred
 
     def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
+        if self.pipeline.text_encoder.device != self.device_torch:
+            self.pipeline.text_encoder.to(self.device_torch)
         prompt_embeds, _ = self.pipeline.encode_prompt(
             prompt,
             do_classifier_free_guidance=False,
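The get_prompt_embeds hunk is the other half of the fix: if aggressive unloading has pushed the text encoder off the target device, it is moved back lazily, right before prompts are encoded. A minimal sketch of that guard; ensure_on_device is an illustrative helper name, and a plain torch.nn.Module stands in for the HF UMT5EncoderModel (which, unlike a bare nn.Module, exposes the .device attribute used in the diff):

import torch

def ensure_on_device(module: torch.nn.Module, device: torch.device) -> None:
    # Same guard as in get_prompt_embeds: only pay the host<->device
    # copy when the module has actually drifted off the target device.
    # A bare nn.Module has no .device attribute, so check a parameter.
    if next(module.parameters()).device != device:
        module.to(device)

encoder = torch.nn.Linear(16, 16)  # stand-in for the text encoder
# Note: use an indexed device ("cuda:0"), because parameters report
# "cuda:0" and torch.device("cuda") != torch.device("cuda:0").
target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ensure_on_device(encoder, target)  # moves only if needed
ensure_on_device(encoder, target)  # already in place: the guard skips the copy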