Check state dict key to auto enable the index_timestep_zero ref method. (#11362)

This commit is contained in:
comfyanonymous
2025-12-16 14:03:17 -08:00
committed by GitHub
parent 65e2103b09
commit ffdd53b327
2 changed files with 6 additions and 1 deletions

View File

@@ -363,6 +363,9 @@ class QwenImageTransformer2DModel(nn.Module):
for _ in range(num_layers)
])
if self.default_ref_method == "index_timestep_zero":
self.register_buffer("__index_timestep_zero__", torch.tensor([]))
if final_layer:
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)