mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-11 05:29:48 +00:00
Fix issue where sometimes the transformer does not get loaded properly.
This commit is contained in:
@@ -134,10 +134,14 @@ class DualWanTransformer3DModel(torch.nn.Module):
|
||||
getattr(self, t_name).to(self.device_torch)
|
||||
torch.cuda.empty_cache()
|
||||
self._active_transformer_name = t_name
|
||||
|
||||
if self.transformer.device != hidden_states.device:
|
||||
raise ValueError(
|
||||
f"Transformer device {self.transformer.device} does not match hidden states device {hidden_states.device}"
|
||||
)
|
||||
if self.low_vram:
|
||||
# move other transformer to cpu
|
||||
other_tname = 'transformer_1' if t_name == 'transformer_2' else 'transformer_2'
|
||||
getattr(self, other_tname).to("cpu")
|
||||
|
||||
self.transformer.to(hidden_states.device)
|
||||
|
||||
return self.transformer(
|
||||
hidden_states=hidden_states,
|
||||
|
||||
Reference in New Issue
Block a user