diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py index fdd76051..b3f4af3f 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py +++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py @@ -302,6 +302,8 @@ class Wan2214bModel(Wan21): if self.model_config.low_vram: self.print_and_status_update("Moving transformer 1 to CPU") transformer_1.to("cpu") + else: + transformer_1.to(self.device_torch) self.print_and_status_update("Loading transformer 2") dtype = self.torch_dtype @@ -327,6 +329,8 @@ class Wan2214bModel(Wan21): if self.model_config.low_vram: self.print_and_status_update("Moving transformer 2 to CPU") transformer_2.to("cpu") + else: + transformer_2.to(self.device_torch) # make the combined model self.print_and_status_update("Creating DualWanTransformer3DModel")