mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
BlockSparseMLP: Improved batch routing
This commit is contained in:
@@ -226,6 +226,7 @@ BC_BlockSparseMLP::BC_BlockSparseMLP
|
||||
at::Tensor _interm_g,
|
||||
at::Tensor _interm_u,
|
||||
at::Tensor _interm_a,
|
||||
at::Tensor _interm_a2,
|
||||
at::Tensor _out_d,
|
||||
at::Tensor _out_d2,
|
||||
c10::optional<at::Tensor> _out_d_sh,
|
||||
@@ -270,6 +271,7 @@ BC_BlockSparseMLP::BC_BlockSparseMLP
|
||||
interm_g (std::move(_interm_g)),
|
||||
interm_u (std::move(_interm_u)),
|
||||
interm_a (std::move(_interm_a)),
|
||||
interm_a2 (std::move(_interm_a2)),
|
||||
out_d (std::move(_out_d)),
|
||||
out_d2 (std::move(_out_d2)),
|
||||
out_d_sh (std::move(_out_d_sh)),
|
||||
@@ -342,7 +344,7 @@ void BC_BlockSparseMLP::run_single_expert
|
||||
{
|
||||
int bsz = y.size(0);
|
||||
|
||||
at::Tensor ai = interm_a.slice(0, 0, bsz);
|
||||
at::Tensor ai = interm_a2.slice(0, 0, bsz);
|
||||
at::Tensor oi = out_d2.slice(0, 0, bsz);
|
||||
|
||||
if (use_mgemm)
|
||||
@@ -444,10 +446,10 @@ void BC_BlockSparseMLP::run_single_expert_dq
|
||||
{
|
||||
int bsz = y.size(0);
|
||||
|
||||
at::Tensor yh1 = yh[0];
|
||||
at::Tensor yh2 = yh[0];
|
||||
at::Tensor interm1 = interm[0];
|
||||
at::Tensor interm2 = interm[1];
|
||||
at::Tensor yh1 = yh.slice(0, 0, bsz);
|
||||
at::Tensor yh2 = yh.slice(0, bsz, bsz * 2);
|
||||
at::Tensor interm1 = interm.slice(0, 0, bsz);
|
||||
at::Tensor interm2 = interm.slice(0, bsz, bsz * 2);
|
||||
|
||||
had_r_128_dual(y, yh1, gates[expert_idx]->suh, c10::nullopt,
|
||||
y, yh2, ups[expert_idx]->suh, c10::nullopt, 1.0);
|
||||
|
||||
@@ -26,6 +26,7 @@ struct BC_BlockSparseMLP
|
||||
at::Tensor interm_g;
|
||||
at::Tensor interm_u;
|
||||
at::Tensor interm_a;
|
||||
at::Tensor interm_a2;
|
||||
at::Tensor out_d;
|
||||
at::Tensor out_d2;
|
||||
c10::optional<at::Tensor> out_d_sh;
|
||||
@@ -83,6 +84,7 @@ struct BC_BlockSparseMLP
|
||||
at::Tensor _interm_g,
|
||||
at::Tensor _interm_u,
|
||||
at::Tensor _interm_a,
|
||||
at::Tensor _interm_a2,
|
||||
at::Tensor _out_d,
|
||||
at::Tensor _out_d2,
|
||||
c10::optional<at::Tensor> _out_d_sh,
|
||||
|
||||
@@ -9,6 +9,7 @@ py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSp
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
c10::optional<at::Tensor>,
|
||||
c10::optional<at::Tensor>,
|
||||
at::Tensor,
|
||||
@@ -51,6 +52,7 @@ py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSp
|
||||
py::arg("interm_g"),
|
||||
py::arg("interm_u"),
|
||||
py::arg("interm_a"),
|
||||
py::arg("interm_a2"),
|
||||
py::arg("out_d"),
|
||||
py::arg("out_d2"),
|
||||
py::arg("out_d_sh"),
|
||||
|
||||
@@ -400,6 +400,7 @@ class BlockSparseMLP(Module):
|
||||
interm_a = temp_activa[:numex].view(numex, 1, I)
|
||||
yh2 = temp_hidden
|
||||
interm_gu = temp_interm
|
||||
interm_a2 = temp_activa
|
||||
out_d = temp_output[:numex].view(numex, 1, H)
|
||||
out_d2 = temp_output
|
||||
|
||||
@@ -457,6 +458,7 @@ class BlockSparseMLP(Module):
|
||||
cfg.interm_g,
|
||||
cfg.interm_u,
|
||||
cfg.interm_a,
|
||||
interm_a2,
|
||||
cfg.out_d,
|
||||
cfg.out_d2,
|
||||
sh_exp_t,
|
||||
@@ -636,7 +638,7 @@ class BlockSparseMLP(Module):
|
||||
|
||||
current_state = y.index_select(0, top_x)
|
||||
|
||||
if self.bc is not None:
|
||||
if self.bc is not None and False:
|
||||
if count <= TEMP_ROWS:
|
||||
self.bc.run_single_expert(current_state, expert_idx)
|
||||
current_state = self.experts_cfg.out_d2[:count]
|
||||
@@ -647,12 +649,13 @@ class BlockSparseMLP(Module):
|
||||
device = self.device
|
||||
)
|
||||
interm = torch.empty(
|
||||
(2, count, self.intermediate_size),
|
||||
(count * 2, self.intermediate_size),
|
||||
dtype = self.interm_dtype,
|
||||
device = self.device
|
||||
)
|
||||
interm_a = interm[0] if self.interm_dtype == torch.half else torch.empty_like(interm[0], dtype = torch.half)
|
||||
yh = torch.empty((2, count, self.hidden_size), dtype = torch.half, device = self.device)
|
||||
interm_a = interm[:count] if self.interm_dtype == torch.half else \
|
||||
torch.empty_like(interm[:count], dtype = torch.half)
|
||||
yh = torch.empty((count * 2, self.hidden_size), dtype = torch.half, device = self.device)
|
||||
self.bc.run_single_expert_dq(current_state, expert_idx, yh, interm, interm_a, out_state)
|
||||
current_state = out_state
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user