BlockSparseMLP: Improved batch routing

This commit is contained in:
turboderp
2026-03-07 02:43:50 +01:00
parent 8e192e12f7
commit 168f21b0ec
4 changed files with 18 additions and 9 deletions

View File

@@ -226,6 +226,7 @@ BC_BlockSparseMLP::BC_BlockSparseMLP
at::Tensor _interm_g,
at::Tensor _interm_u,
at::Tensor _interm_a,
at::Tensor _interm_a2,
at::Tensor _out_d,
at::Tensor _out_d2,
c10::optional<at::Tensor> _out_d_sh,
@@ -270,6 +271,7 @@ BC_BlockSparseMLP::BC_BlockSparseMLP
interm_g (std::move(_interm_g)),
interm_u (std::move(_interm_u)),
interm_a (std::move(_interm_a)),
interm_a2 (std::move(_interm_a2)),
out_d (std::move(_out_d)),
out_d2 (std::move(_out_d2)),
out_d_sh (std::move(_out_d_sh)),
@@ -342,7 +344,7 @@ void BC_BlockSparseMLP::run_single_expert
{
int bsz = y.size(0);
at::Tensor ai = interm_a.slice(0, 0, bsz);
at::Tensor ai = interm_a2.slice(0, 0, bsz);
at::Tensor oi = out_d2.slice(0, 0, bsz);
if (use_mgemm)
@@ -444,10 +446,10 @@ void BC_BlockSparseMLP::run_single_expert_dq
{
int bsz = y.size(0);
at::Tensor yh1 = yh[0];
at::Tensor yh2 = yh[0];
at::Tensor interm1 = interm[0];
at::Tensor interm2 = interm[1];
at::Tensor yh1 = yh.slice(0, 0, bsz);
at::Tensor yh2 = yh.slice(0, bsz, bsz * 2);
at::Tensor interm1 = interm.slice(0, 0, bsz);
at::Tensor interm2 = interm.slice(0, bsz, bsz * 2);
had_r_128_dual(y, yh1, gates[expert_idx]->suh, c10::nullopt,
y, yh2, ups[expert_idx]->suh, c10::nullopt, 1.0);

View File

@@ -26,6 +26,7 @@ struct BC_BlockSparseMLP
at::Tensor interm_g;
at::Tensor interm_u;
at::Tensor interm_a;
at::Tensor interm_a2;
at::Tensor out_d;
at::Tensor out_d2;
c10::optional<at::Tensor> out_d_sh;
@@ -83,6 +84,7 @@ struct BC_BlockSparseMLP
at::Tensor _interm_g,
at::Tensor _interm_u,
at::Tensor _interm_a,
at::Tensor _interm_a2,
at::Tensor _out_d,
at::Tensor _out_d2,
c10::optional<at::Tensor> _out_d_sh,

View File

@@ -9,6 +9,7 @@ py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSp
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
c10::optional<at::Tensor>,
at::Tensor,
@@ -51,6 +52,7 @@ py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSp
py::arg("interm_g"),
py::arg("interm_u"),
py::arg("interm_a"),
py::arg("interm_a2"),
py::arg("out_d"),
py::arg("out_d2"),
py::arg("out_d_sh"),

View File

@@ -400,6 +400,7 @@ class BlockSparseMLP(Module):
interm_a = temp_activa[:numex].view(numex, 1, I)
yh2 = temp_hidden
interm_gu = temp_interm
interm_a2 = temp_activa
out_d = temp_output[:numex].view(numex, 1, H)
out_d2 = temp_output
@@ -457,6 +458,7 @@ class BlockSparseMLP(Module):
cfg.interm_g,
cfg.interm_u,
cfg.interm_a,
interm_a2,
cfg.out_d,
cfg.out_d2,
sh_exp_t,
@@ -636,7 +638,7 @@ class BlockSparseMLP(Module):
current_state = y.index_select(0, top_x)
if self.bc is not None:
if self.bc is not None and False:
if count <= TEMP_ROWS:
self.bc.run_single_expert(current_state, expert_idx)
current_state = self.experts_cfg.out_d2[:count]
@@ -647,12 +649,13 @@ class BlockSparseMLP(Module):
device = self.device
)
interm = torch.empty(
(2, count, self.intermediate_size),
(count * 2, self.intermediate_size),
dtype = self.interm_dtype,
device = self.device
)
interm_a = interm[0] if self.interm_dtype == torch.half else torch.empty_like(interm[0], dtype = torch.half)
yh = torch.empty((2, count, self.hidden_size), dtype = torch.half, device = self.device)
interm_a = interm[:count] if self.interm_dtype == torch.half else \
torch.empty_like(interm[:count], dtype = torch.half)
yh = torch.empty((count * 2, self.hidden_size), dtype = torch.half, device = self.device)
self.bc.run_single_expert_dq(current_state, expert_idx, yh, interm, interm_a, out_state)
current_state = out_state
else: