diff --git a/exllamav3/modules/block_sparse_mlp.py b/exllamav3/modules/block_sparse_mlp.py index 115f34c..4dbb31e 100644 --- a/exllamav3/modules/block_sparse_mlp.py +++ b/exllamav3/modules/block_sparse_mlp.py @@ -615,12 +615,12 @@ class BlockSparseMLP(Module): # Group once by local expert id (including sentinel for expert-P mode) order = flat_expert_local.argsort() - local_sorted = flat_expert_local[order] + # local_sorted = flat_expert_local[order] token_sorted = flat_token[order] weight_sorted = flat_weight[order] # Count how many assignments per expert - expert_count = torch.bincount(local_sorted, minlength = E + 1) + expert_count = torch.bincount(flat_expert_local, minlength = E + 1) expert_ptr = torch.empty(E + 2, device = y.device, dtype = torch.long) expert_ptr[0] = 0 expert_ptr[1:] = expert_count.cumsum(0)