BlockSparseMLP: Keep buffers between experts when possible

This commit is contained in:
turboderp
2026-03-07 04:54:02 +01:00
parent 168f21b0ec
commit 7ad51c0422

View File

@@ -626,6 +626,11 @@ class BlockSparseMLP(Module):
expert_ptr[1:] = expert_count.cumsum(0)
expert_ptr = expert_ptr.tolist()
out_state = None
interm = None
interm_a = None
max_count = 0
for expert_idx in range(num_ex):
start = expert_ptr[expert_idx]
end = expert_ptr[expert_idx + 1]
@@ -638,26 +643,40 @@ class BlockSparseMLP(Module):
current_state = y.index_select(0, top_x)
if self.bc is not None and False:
if self.bc is not None:
if count <= TEMP_ROWS:
self.bc.run_single_expert(current_state, expert_idx)
current_state = self.experts_cfg.out_d2[:count]
else:
out_state = torch.empty(
(count, self.hidden_size),
dtype = self.out_dtype or torch.half,
device = self.device
)
interm = torch.empty(
(count * 2, self.intermediate_size),
dtype = self.interm_dtype,
device = self.device
)
interm_a = interm[:count] if self.interm_dtype == torch.half else \
torch.empty_like(interm[:count], dtype = torch.half)
if count > max_count:
out_state = torch.empty(
(count, self.hidden_size),
dtype = self.out_dtype or torch.half,
device = self.device
)
interm = torch.empty(
(count * 2, self.intermediate_size),
dtype = self.interm_dtype,
device = self.device
)
interm_a = interm[:count] if self.interm_dtype == torch.half else \
torch.empty_like(interm[:count], dtype = torch.half)
out_state_ = out_state
interm_ = interm
interm_a_ = interm_a
max_count = count
elif count == max_count:
out_state_ = out_state
interm_ = interm
interm_a_ = interm_a
else:
out_state_ = out_state[:count]
interm_ = interm[:count * 2]
interm_a_ = interm_a[:count]
yh = torch.empty((count * 2, self.hidden_size), dtype = torch.half, device = self.device)
self.bc.run_single_expert_dq(current_state, expert_idx, yh, interm, interm_a, out_state)
current_state = out_state
self.bc.run_single_expert_dq(current_state, expert_idx, yh, interm_, interm_a_, out_state)
current_state = out_state_
else:
def mlp(exp_i, xc):
g = self.gates[exp_i].forward(xc, params)