diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py index 3266d320..79886ea4 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py +++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py @@ -134,10 +134,14 @@ class DualWanTransformer3DModel(torch.nn.Module): getattr(self, t_name).to(self.device_torch) torch.cuda.empty_cache() self._active_transformer_name = t_name + if self.transformer.device != hidden_states.device: - raise ValueError( - f"Transformer device {self.transformer.device} does not match hidden states device {hidden_states.device}" - ) + if self.low_vram: + # move other transformer to cpu + other_tname = 'transformer_1' if t_name == 'transformer_2' else 'transformer_2' + getattr(self, other_tname).to("cpu") + + self.transformer.to(hidden_states.device) return self.transformer( hidden_states=hidden_states,