From 7ad51c0422fd8ea74b7b4960c38bd21209c4f647 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Sat, 7 Mar 2026 04:54:02 +0100
Subject: [PATCH] BlockSparseMLP: Keep buffers between experts when possible

---
NOTE(review): `run_single_expert_dq` is now passed the sliced `out_state_`
view, consistent with `interm_`/`interm_a_` and with the read-back
`current_state = out_state_`; previously the full-size cached `out_state`
buffer was passed even when count < max_count. (This commentary section is
ignored by `git am`.)

 exllamav3/modules/block_sparse_mlp.py | 49 +++++++++++++++++++--------
 1 file changed, 34 insertions(+), 15 deletions(-)

diff --git a/exllamav3/modules/block_sparse_mlp.py b/exllamav3/modules/block_sparse_mlp.py
index 1646c5f..115f34c 100644
--- a/exllamav3/modules/block_sparse_mlp.py
+++ b/exllamav3/modules/block_sparse_mlp.py
@@ -626,6 +626,11 @@ class BlockSparseMLP(Module):
         expert_ptr[1:] = expert_count.cumsum(0)
         expert_ptr = expert_ptr.tolist()
 
+        out_state = None
+        interm = None
+        interm_a = None
+        max_count = 0
+
         for expert_idx in range(num_ex):
             start = expert_ptr[expert_idx]
             end = expert_ptr[expert_idx + 1]
@@ -638,26 +643,40 @@
             current_state = y.index_select(0, top_x)
 
-            if self.bc is not None and False:
+            if self.bc is not None:
                 if count <= TEMP_ROWS:
                     self.bc.run_single_expert(current_state, expert_idx)
                     current_state = self.experts_cfg.out_d2[:count]
                 else:
-                    out_state = torch.empty(
-                        (count, self.hidden_size),
-                        dtype = self.out_dtype or torch.half,
-                        device = self.device
-                    )
-                    interm = torch.empty(
-                        (count * 2, self.intermediate_size),
-                        dtype = self.interm_dtype,
-                        device = self.device
-                    )
-                    interm_a = interm[:count] if self.interm_dtype == torch.half else \
-                        torch.empty_like(interm[:count], dtype = torch.half)
+                    if count > max_count:
+                        out_state = torch.empty(
+                            (count, self.hidden_size),
+                            dtype = self.out_dtype or torch.half,
+                            device = self.device
+                        )
+                        interm = torch.empty(
+                            (count * 2, self.intermediate_size),
+                            dtype = self.interm_dtype,
+                            device = self.device
+                        )
+                        interm_a = interm[:count] if self.interm_dtype == torch.half else \
+                            torch.empty_like(interm[:count], dtype = torch.half)
+                        out_state_ = out_state
+                        interm_ = interm
+                        interm_a_ = interm_a
+                        max_count = count
+                    elif count == max_count:
+                        out_state_ = out_state
+                        interm_ = interm
+                        interm_a_ = interm_a
+                    else:
+                        out_state_ = out_state[:count]
+                        interm_ = interm[:count * 2]
+                        interm_a_ = interm_a[:count]
+
                     yh = torch.empty((count * 2, self.hidden_size), dtype = torch.half, device = self.device)
-                    self.bc.run_single_expert_dq(current_state, expert_idx, yh, interm, interm_a, out_state)
-                    current_state = out_state
+                    self.bc.run_single_expert_dq(current_state, expert_idx, yh, interm_, interm_a_, out_state_)
+                    current_state = out_state_
             else:
                 def mlp(exp_i, xc):
                     g = self.gates[exp_i].forward(xc, params)