GatedDeltaNet: Set chunked seqlen threshold to num_v_heads (prevents warning from FLA)

This commit is contained in:
turboderp
2026-03-03 22:33:58 +01:00
parent d3d76d38f8
commit eb1686a840

View File

@@ -618,7 +618,7 @@ class GatedDeltaNet(Module):
# Use chunked rule when advantageous and available
# TODO: At least warn if chunked rule (i.e. flash-linear-attention) is not available
# since performance will tank on prompt ingestion
-        if seqlen >= 32 and chunk_gated_delta_rule is not None:
+        if seqlen >= self.num_v_heads and chunk_gated_delta_rule is not None:
mixed_qkv = mixed_qkv.transpose(1, 2)
q, k, v = torch.split(mixed_qkv, [self.k_dim, self.k_dim, self.v_dim], dim = -1)