mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Fix issue with device placement on te
This commit is contained in:
@@ -83,6 +83,27 @@ scheduler_config = {
|
|||||||
|
|
||||||
|
|
||||||
class AggressiveWanUnloadPipeline(WanPipeline):
|
class AggressiveWanUnloadPipeline(WanPipeline):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer: AutoTokenizer,
|
||||||
|
text_encoder: UMT5EncoderModel,
|
||||||
|
transformer: WanTransformer3DModel,
|
||||||
|
vae: AutoencoderKLWan,
|
||||||
|
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||||
|
device: torch.device = torch.device("cuda"),
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
transformer=transformer,
|
||||||
|
vae=vae,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
self._exec_device = device
|
||||||
|
@property
|
||||||
|
def _execution_device(self):
|
||||||
|
return self._exec_device
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self: WanPipeline,
|
self: WanPipeline,
|
||||||
prompt: Union[str, List[str]] = None,
|
prompt: Union[str, List[str]] = None,
|
||||||
@@ -459,6 +480,7 @@ class Wan21(BaseModel):
|
|||||||
text_encoder=self.text_encoder,
|
text_encoder=self.text_encoder,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
|
device=self.device_torch
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
pipeline = WanPipeline(
|
pipeline = WanPipeline(
|
||||||
@@ -540,6 +562,8 @@ class Wan21(BaseModel):
|
|||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
|
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
|
||||||
|
if self.pipeline.text_encoder.device != self.device_torch:
|
||||||
|
self.pipeline.text_encoder.to(self.device_torch)
|
||||||
prompt_embeds, _ = self.pipeline.encode_prompt(
|
prompt_embeds, _ = self.pipeline.encode_prompt(
|
||||||
prompt,
|
prompt,
|
||||||
do_classifier_free_guidance=False,
|
do_classifier_free_guidance=False,
|
||||||
|
|||||||
Reference in New Issue
Block a user