mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-28 02:44:04 +00:00
Fix dtype issue in embeddings connector. (#12570)
This commit is contained in:
@@ -234,7 +234,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
|
||||
return indices
|
||||
|
||||
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
||||
def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
|
||||
dim = self.inner_dim
|
||||
n_elem = 2 # 2 because of cos and sin
|
||||
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||
@@ -247,7 +247,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
)
|
||||
else:
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
||||
return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -288,7 +288,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
indices_grid = indices_grid[None, None, :]
|
||||
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
||||
freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
|
||||
|
||||
# 2. Blocks
|
||||
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||
|
||||
Reference in New Issue
Block a user