From 1c96b95617f6e2dd927c05fb8931842baeb987a7 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 14 Aug 2025 14:24:41 -0600 Subject: [PATCH] Fix issue where sometimes the transformer does not get loaded properly. --- .../diffusion_models/wan22/wan22_14b_model.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py index 3266d320..79886ea4 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py +++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py @@ -134,10 +134,14 @@ class DualWanTransformer3DModel(torch.nn.Module): getattr(self, t_name).to(self.device_torch) torch.cuda.empty_cache() self._active_transformer_name = t_name + if self.transformer.device != hidden_states.device: - raise ValueError( - f"Transformer device {self.transformer.device} does not match hidden states device {hidden_states.device}" - ) + if self.low_vram: + # move other transformer to cpu + other_tname = 'transformer_1' if t_name == 'transformer_2' else 'transformer_2' + getattr(self, other_tname).to("cpu") + + self.transformer.to(hidden_states.device) return self.transformer( hidden_states=hidden_states,