From 7e37918fbcd459cd5cec48bc25638674e243b4e3 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 7 Mar 2025 22:15:24 -0700 Subject: [PATCH] Double tap module casting as it doesn't seem to happen every time. --- toolkit/models/wan21.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/toolkit/models/wan21.py b/toolkit/models/wan21.py index 10e1b9c8..af1ede2b 100644 --- a/toolkit/models/wan21.py +++ b/toolkit/models/wan21.py @@ -327,7 +327,7 @@ class Wan21(BaseModel): transformer_path, subfolder=subfolder, torch_dtype=dtype, - ) + ).to(dtype=dtype) if self.model_config.split_model_over_gpus: raise ValueError( @@ -396,7 +396,7 @@ class Wan21(BaseModel): tokenizer = AutoTokenizer.from_pretrained( base_model_path, subfolder="tokenizer", torch_dtype=dtype) text_encoder = UMT5EncoderModel.from_pretrained( - base_model_path, subfolder="text_encoder", torch_dtype=dtype) + base_model_path, subfolder="text_encoder", torch_dtype=dtype).to(dtype=dtype) text_encoder.to(self.device_torch, dtype=dtype) flush() @@ -416,7 +416,7 @@ class Wan21(BaseModel): self.print_and_status_update("Loading VAE") # todo, example does float 32? check if quality suffers vae = AutoencoderKLWan.from_pretrained( - base_model_path, subfolder="vae", torch_dtype=dtype) + base_model_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype) flush() self.print_and_status_update("Making pipe")