BlockSparseMLP: Skip redundant gather

This commit is contained in:
turboderp
2026-03-11 20:25:56 +01:00
parent d52c49c17f
commit 1b9e58c9b5

View File

@@ -615,12 +615,12 @@ class BlockSparseMLP(Module):
# Group once by local expert id (including sentinel for expert-P mode)
order = flat_expert_local.argsort()
local_sorted = flat_expert_local[order]
# local_sorted = flat_expert_local[order]
token_sorted = flat_token[order]
weight_sorted = flat_weight[order]
# Count how many assignments per expert
expert_count = torch.bincount(local_sorted, minlength = E + 1)
expert_count = torch.bincount(flat_expert_local, minlength = E + 1)
expert_ptr = torch.empty(E + 2, device = y.device, dtype = torch.long)
expert_ptr[0] = 0
expert_ptr[1:] = expert_count.cumsum(0)