Fix dtype issue in embeddings connector. (#12570)

This commit is contained in:
comfyanonymous
2026-02-22 00:18:20 -08:00
committed by GitHub
parent f266b8d352
commit 07ca6852e8

View File

@@ -234,7 +234,7 @@ class Embeddings1DConnector(nn.Module):
return indices
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
dim = self.inner_dim
n_elem = 2 # 2 because of cos and sin
freqs = self.precompute_freqs(indices_grid, spacing)
@@ -247,7 +247,7 @@ class Embeddings1DConnector(nn.Module):
)
else:
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
def forward(
self,
@@ -288,7 +288,7 @@ class Embeddings1DConnector(nn.Module):
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
)
indices_grid = indices_grid[None, None, :]
freqs_cis = self.precompute_freqs_cis(indices_grid)
freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
# 2. Blocks
for block_idx, block in enumerate(self.transformer_1d_blocks):