Fix issue where sometimes the transformer does not get loaded properly.

This commit is contained in:
Jaret Burkett
2025-08-14 14:24:41 -06:00
parent 3413fa537f
commit 1c96b95617

View File

@@ -134,10 +134,14 @@ class DualWanTransformer3DModel(torch.nn.Module):
getattr(self, t_name).to(self.device_torch)
torch.cuda.empty_cache()
self._active_transformer_name = t_name
if self.transformer.device != hidden_states.device:
raise ValueError(
f"Transformer device {self.transformer.device} does not match hidden states device {hidden_states.device}"
)
if self.low_vram:
# move other transformer to cpu
other_tname = 'transformer_1' if t_name == 'transformer_2' else 'transformer_2'
getattr(self, other_tname).to("cpu")
self.transformer.to(hidden_states.device)
return self.transformer(
hidden_states=hidden_states,