BlockSparseMLP: Improved batch routing

turboderp
2026-03-07 01:21:34 +01:00
parent 86174510bd
commit 766a28dc60
6 changed files with 615 additions and 109 deletions
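
In short: the batched (Torch/C++) path no longer builds a one-hot expert mask and calls torch.where per expert; it flattens the (token, expert) assignments, sorts them once by expert id, and walks contiguous buckets. A minimal sketch of the grouping, illustrative only (experts[e] stands in for the fused gate/up/down expert path, and the sentinel handling for expert-parallel splits is omitted):

import torch

def moe_forward_grouped(y, selected_experts, routing_weights, experts, E):
    # y: [num_tokens, hidden]; selected_experts/routing_weights: [num_tokens, top_k]
    num_tokens, top_k = selected_experts.shape
    flat_expert = selected_experts.reshape(-1)
    flat_weight = routing_weights.reshape(-1)
    flat_token = torch.arange(num_tokens, device = y.device).repeat_interleave(top_k)

    # Sort assignments once by expert id, then walk contiguous buckets
    order = flat_expert.argsort()
    expert_sorted = flat_expert[order]
    token_sorted = flat_token[order]
    weight_sorted = flat_weight[order]

    count = torch.bincount(expert_sorted, minlength = E)
    ptr = [0] + count.cumsum(0).tolist()

    out = torch.zeros_like(y)
    for e in range(E):
        start, end = ptr[e], ptr[e + 1]
        if start == end:
            continue
        tokens = token_sorted[start:end]
        h = experts[e](y.index_select(0, tokens))  # one batched call per expert
        out.index_add_(0, tokens, h * weight_sorted[start:end, None].to(h.dtype))
    return out

In the commit, each contiguous bucket is then fed to run_single_expert (quantized kernels, small batches) or run_single_expert_dq (dequantize + hgemm, large batches) rather than a Python-level MLP.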


@@ -6,6 +6,8 @@
#include "../util.h"
#include "../hgemm.cuh"
#include "../quant/exl3_gemm.cuh"
#include "../quant/hadamard.cuh"
#include "../quant/reconstruct.cuh"
#include "../activation.cuh"
#include "../add.cuh"
@@ -69,7 +71,7 @@ void BC_BlockSparseMLP::run_bsz1_gr
Graph* graph
)
{
py::gil_scoped_release _;
//py::gil_scoped_release _;
const at::Tensor& yi = y.unsqueeze(0);
exl3_mgemm_gr
@@ -213,4 +215,258 @@ void BC_BlockSparseMLP::run_bsz1
graph_bsz1.launch(args, stream);
#endif
#undef USE_GRAPH
}
BC_BlockSparseMLP::BC_BlockSparseMLP
(
at::Tensor _yh2,
at::Tensor _yh,
at::Tensor _interm_gu,
at::Tensor _interm_g,
at::Tensor _interm_u,
at::Tensor _interm_a,
at::Tensor _out_d,
at::Tensor _out_d2,
c10::optional<at::Tensor> _out_d_sh,
c10::optional<at::Tensor> _z,
at::Tensor _dq_temp_up,
at::Tensor _dq_temp_down,
int _min_expert,
int _max_expert,
at::Tensor _gate_ptrs_trellis,
at::Tensor _gate_ptrs_suh,
at::Tensor _gate_ptrs_svh,
int _gate_K,
bool _gate_mcg,
bool _gate_mul1,
at::Tensor _up_ptrs_trellis,
at::Tensor _up_ptrs_suh,
at::Tensor _up_ptrs_svh,
int _up_K,
bool _up_mcg,
bool _up_mul1,
at::Tensor _down_ptrs_trellis,
at::Tensor _down_ptrs_suh,
at::Tensor _down_ptrs_svh,
int _down_K,
bool _down_mcg,
bool _down_mul1,
bool _act_silu,
bool _act_gelu,
std::shared_ptr<BC_GatedMLP> _shared_experts,
std::shared_ptr<BC_LinearFP16> _shared_gate,
float _act_limit,
std::vector<std::shared_ptr<BC_LinearEXL3>> _gates,
std::vector<std::shared_ptr<BC_LinearEXL3>> _ups,
std::vector<std::shared_ptr<BC_LinearEXL3>> _downs,
at::Tensor _gu_trellis_ptr,
at::Tensor _gu_suh_ptr,
at::Tensor _gu_svh_ptr
) :
yh2 (std::move(_yh2)),
yh (std::move(_yh)),
interm_gu (std::move(_interm_gu)),
interm_g (std::move(_interm_g)),
interm_u (std::move(_interm_u)),
interm_a (std::move(_interm_a)),
out_d (std::move(_out_d)),
out_d2 (std::move(_out_d2)),
out_d_sh (std::move(_out_d_sh)),
z (std::move(_z)),
dq_temp_up (std::move(_dq_temp_up)),
dq_temp_down (std::move(_dq_temp_down)),
min_expert (_min_expert),
max_expert (_max_expert),
gate_ptrs_trellis (std::move(_gate_ptrs_trellis)),
gate_ptrs_suh (std::move(_gate_ptrs_suh)),
gate_ptrs_svh (std::move(_gate_ptrs_svh)),
gate_K (_gate_K),
gate_mcg (_gate_mcg),
gate_mul1 (_gate_mul1),
up_ptrs_trellis (std::move(_up_ptrs_trellis)),
up_ptrs_suh (std::move(_up_ptrs_suh)),
up_ptrs_svh (std::move(_up_ptrs_svh)),
up_K (_up_K),
up_mcg (_up_mcg),
up_mul1 (_up_mul1),
down_ptrs_trellis (std::move(_down_ptrs_trellis)),
down_ptrs_suh (std::move(_down_ptrs_suh)),
down_ptrs_svh (std::move(_down_ptrs_svh)),
down_K (_down_K),
down_mcg (_down_mcg),
down_mul1 (_down_mul1),
act_silu (_act_silu),
act_gelu (_act_gelu),
shared_experts (_shared_experts),
shared_gate (_shared_gate),
act_limit (_act_limit),
gates (_gates),
ups (_ups),
downs (_downs),
gu_trellis_ptr (_gu_trellis_ptr),
gu_suh_ptr (_gu_suh_ptr),
gu_svh_ptr (_gu_svh_ptr)
{
gate_ptrs_trellis_cpu = gate_ptrs_trellis.cpu();
gate_ptrs_suh_cpu = gate_ptrs_suh.cpu();
gate_ptrs_svh_cpu = gate_ptrs_svh.cpu();
up_ptrs_trellis_cpu = up_ptrs_trellis.cpu();
up_ptrs_suh_cpu = up_ptrs_suh.cpu();
up_ptrs_svh_cpu = up_ptrs_svh.cpu();
down_ptrs_trellis_cpu = down_ptrs_trellis.cpu();
down_ptrs_suh_cpu = down_ptrs_suh.cpu();
down_ptrs_svh_cpu = down_ptrs_svh.cpu();
max_experts_per_token = interm_g.size(0);
max_tokens_per_expert = max_experts_per_token;
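// Pre-sliced views for every row count up to the buffer limit, so hot paths
// can index a ready-made view instead of slicing at runtime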
for (int i = 0; i < max_tokens_per_expert; ++i)
{
interm_g_single.push_back(interm_g.squeeze(1).slice(0, 0, i + 1));
interm_u_single.push_back(interm_u.squeeze(1).slice(0, 0, i + 1));
interm_a_single.push_back(interm_a.squeeze(1).slice(0, 0, i + 1));
out_d_single.push_back(out_d.squeeze(1).slice(0, 0, i + 1));
}
TORCH_CHECK(max_expert <= MAX_EXPERTS, "BC_BlockSparseMLP: Too many experts");
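// The fused gate+up mgemm path requires both projections to share the same trellis width K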
use_mgemm = gate_K == up_K;
}
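// Forward one contiguous batch of tokens through a single expert. When gate and
// up share the same K, both projections run as one fused mgemm over the stacked
// gate/up pointer tables; otherwise they fall back to two separate exl3_gemm calls.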
void BC_BlockSparseMLP::run_single_expert
(
const at::Tensor& y,
const int expert_idx
)
{
int bsz = y.size(0);
at::Tensor ai = interm_a.slice(0, 0, bsz);
at::Tensor oi = out_d2.slice(0, 0, bsz);
if (use_mgemm)
{
at::Tensor yb = y.unsqueeze(0);
at::Tensor gui = interm_gu.slice(0, 0, bsz * 2).view({2, bsz, interm_gu.size(1)});
at::Tensor yh2i = yh2.slice(0, 0, bsz * 2).view({2, bsz, yh2.size(1)});
exl3_mgemm
(
yb,
gu_trellis_ptr[expert_idx],
gui,
gu_suh_ptr[expert_idx],
yh2i,
gu_svh_ptr[expert_idx],
c10::nullopt,
c10::nullopt,
gate_K,
-1,
gate_mcg,
gate_mul1,
-1,
-1,
0
);
at::Tensor gi = gui[0];
at::Tensor ui = gui[1];
if (act_silu)
silu_mul(gi, ui, ai, act_limit);
else if (act_gelu)
gelu_mul(gi, ui, ai, act_limit);
}
else
{
at::Tensor gi = interm_gu.slice(0, 0, bsz);
at::Tensor ui = interm_gu.slice(0, bsz, bsz * 2);
exl3_gemm
(
y,
gates[expert_idx]->trellis,
gi,
gates[expert_idx]->suh,
yh,
gates[expert_idx]->svh,
-1,
gate_mcg,
gate_mul1,
0
);
exl3_gemm
(
y,
ups[expert_idx]->trellis,
ui,
ups[expert_idx]->suh,
yh,
ups[expert_idx]->svh,
-1,
up_mcg,
up_mul1,
0
);
if (act_silu)
silu_mul(gi, ui, ai, act_limit);
else if (act_gelu)
gelu_mul(gi, ui, ai, act_limit);
}
exl3_gemm
(
ai,
downs[expert_idx]->trellis,
oi,
downs[expert_idx]->suh,
ai,
downs[expert_idx]->svh,
-1,
down_mcg,
down_mul1,
0
);
}
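// Dequantizing fallback for batches larger than the preallocated temp buffers:
// reconstruct the expert's weights to fp16 and use plain hgemm, with the dual
// Hadamard transform handling the gate/up input rotations in one launch.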
void BC_BlockSparseMLP::run_single_expert_dq
(
const at::Tensor& y,
const int expert_idx,
at::Tensor& yh,
at::Tensor& interm,
at::Tensor& interm_a,
at::Tensor& out
)
{
int bsz = y.size(0);
at::Tensor yh1 = yh[0];
at::Tensor yh2 = yh[1];
at::Tensor interm1 = interm[0];
at::Tensor interm2 = interm[1];
had_r_128_dual(y, yh1, gates[expert_idx]->suh, c10::nullopt,
y, yh2, ups[expert_idx]->suh, c10::nullopt, 1.0);
reconstruct(dq_temp_up, gates[expert_idx]->trellis, gate_K, gate_mcg, gate_mul1);
hgemm(yh1, dq_temp_up, interm1);
reconstruct(dq_temp_up, ups[expert_idx]->trellis, up_K, up_mcg, up_mul1);
hgemm(yh2, dq_temp_up, interm2);
had_r_128_dual(interm1, interm1, c10::nullopt, gates[expert_idx]->svh,
interm2, interm2, c10::nullopt, ups[expert_idx]->svh, 1.0);
if (act_silu)
silu_mul(interm1, interm2, interm_a, act_limit);
else if (act_gelu)
gelu_mul(interm1, interm2, interm_a, act_limit);
had_r_128(interm_a, interm_a, downs[expert_idx]->suh, c10::nullopt, 1.0);
reconstruct(dq_temp_down, downs[expert_idx]->trellis, down_K, down_mcg, down_mul1);
hgemm(interm_a, dq_temp_down, out);
had_r_128(out, out, c10::nullopt, downs[expert_idx]->svh, 1.0);
}


@@ -9,6 +9,8 @@ namespace py = pybind11;
#include "linear.h"
#include "../graph.cuh"
#define MAX_EXPERTS 512
std::tuple<at::Tensor, at::Tensor> blocksparse_mlp_routing(
int bsz,
const py::object& cfg,
@@ -18,30 +20,35 @@ std::tuple<at::Tensor, at::Tensor> blocksparse_mlp_routing(
struct BC_BlockSparseMLP
{
at::Tensor yh2;
at::Tensor yh;
at::Tensor interm_gu;
at::Tensor interm_g;
at::Tensor interm_u;
at::Tensor interm_a;
at::Tensor out_d;
at::Tensor out_d2;
c10::optional<at::Tensor> out_d_sh;
c10::optional<at::Tensor> z;
at::Tensor dq_temp_up;
at::Tensor dq_temp_down;
int min_expert;
int max_expert;
at::Tensor gate_ptrs_trellis;
at::Tensor gate_ptrs_suh;
at::Tensor gate_ptrs_svh;
at::Tensor gate_ptrs_trellis; at::Tensor gate_ptrs_trellis_cpu;
at::Tensor gate_ptrs_suh; at::Tensor gate_ptrs_suh_cpu;
at::Tensor gate_ptrs_svh; at::Tensor gate_ptrs_svh_cpu;
int gate_K;
bool gate_mcg;
bool gate_mul1;
at::Tensor up_ptrs_trellis;
at::Tensor up_ptrs_suh;
at::Tensor up_ptrs_svh;
at::Tensor up_ptrs_trellis; at::Tensor up_ptrs_trellis_cpu;
at::Tensor up_ptrs_suh; at::Tensor up_ptrs_suh_cpu;
at::Tensor up_ptrs_svh; at::Tensor up_ptrs_svh_cpu;
int up_K;
bool up_mcg;
bool up_mul1;
at::Tensor down_ptrs_trellis;
at::Tensor down_ptrs_suh;
at::Tensor down_ptrs_svh;
at::Tensor down_ptrs_trellis; at::Tensor down_ptrs_trellis_cpu;
at::Tensor down_ptrs_suh; at::Tensor down_ptrs_suh_cpu;
at::Tensor down_ptrs_svh; at::Tensor down_ptrs_svh_cpu;
int down_K;
bool down_mcg;
bool down_mul1;
@@ -50,18 +57,38 @@ struct BC_BlockSparseMLP
std::shared_ptr<BC_GatedMLP> shared_experts;
std::shared_ptr<BC_LinearFP16> shared_gate;
float act_limit;
std::vector<std::shared_ptr<BC_LinearEXL3>> gates;
std::vector<std::shared_ptr<BC_LinearEXL3>> ups;
std::vector<std::shared_ptr<BC_LinearEXL3>> downs;
at::Tensor gu_trellis_ptr;
at::Tensor gu_suh_ptr;
at::Tensor gu_svh_ptr;
int max_experts_per_token;
int max_tokens_per_expert;
std::vector<at::Tensor> interm_g_single;
std::vector<at::Tensor> interm_u_single;
std::vector<at::Tensor> interm_a_single;
std::vector<at::Tensor> out_d_single;
bool use_mgemm;
Graph graph_bsz1;
BC_BlockSparseMLP
(
at::Tensor _yh2,
at::Tensor _yh,
at::Tensor _interm_gu,
at::Tensor _interm_g,
at::Tensor _interm_u,
at::Tensor _interm_a,
at::Tensor _out_d,
at::Tensor _out_d2,
c10::optional<at::Tensor> _out_d_sh,
c10::optional<at::Tensor> _z,
at::Tensor _dq_temp_up,
at::Tensor _dq_temp_down,
int _min_expert,
int _max_expert,
at::Tensor _gate_ptrs_trellis,
@@ -86,41 +113,14 @@ struct BC_BlockSparseMLP
bool _act_gelu,
std::shared_ptr<BC_GatedMLP> _shared_experts,
std::shared_ptr<BC_LinearFP16> _shared_gate,
float _act_limit
) :
yh (std::move(_yh)),
interm_g (std::move(_interm_g)),
interm_u (std::move(_interm_u)),
interm_a (std::move(_interm_a)),
out_d (std::move(_out_d)),
out_d_sh (std::move(_out_d_sh)),
z (std::move(_z)),
min_expert (_min_expert),
max_expert (_max_expert),
gate_ptrs_trellis (std::move(_gate_ptrs_trellis)),
gate_ptrs_suh (std::move(_gate_ptrs_suh)),
gate_ptrs_svh (std::move(_gate_ptrs_svh)),
gate_K (_gate_K),
gate_mcg (_gate_mcg),
gate_mul1 (_gate_mul1),
up_ptrs_trellis (std::move(_up_ptrs_trellis)),
up_ptrs_suh (std::move(_up_ptrs_suh)),
up_ptrs_svh (std::move(_up_ptrs_svh)),
up_K (_up_K),
up_mcg (_up_mcg),
up_mul1 (_up_mul1),
down_ptrs_trellis (std::move(_down_ptrs_trellis)),
down_ptrs_suh (std::move(_down_ptrs_suh)),
down_ptrs_svh (std::move(_down_ptrs_svh)),
down_K (_down_K),
down_mcg (_down_mcg),
down_mul1 (_down_mul1),
act_silu (_act_silu),
act_gelu (_act_gelu),
shared_experts (_shared_experts),
shared_gate (_shared_gate),
act_limit (_act_limit)
{}
float _act_limit,
std::vector<std::shared_ptr<BC_LinearEXL3>> _gates,
std::vector<std::shared_ptr<BC_LinearEXL3>> _ups,
std::vector<std::shared_ptr<BC_LinearEXL3>> _downs,
at::Tensor _gu_trellis_ptr,
at::Tensor _gu_suh_ptr,
at::Tensor _gu_svh_ptr
);
void run_bsz1_gr
(
@@ -136,4 +136,20 @@ struct BC_BlockSparseMLP
at::Tensor& selected_experts,
at::Tensor& routing_weights
);
void run_single_expert
(
const at::Tensor& y,
const int expert_idx
);
void run_single_expert_dq
(
const at::Tensor& y,
const int expert_idx,
at::Tensor& yh,
at::Tensor& interm,
at::Tensor& interm_a,
at::Tensor& out
);
};


@@ -6,8 +6,13 @@ py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSp
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
c10::optional<at::Tensor>,
at::Tensor,
at::Tensor,
int,
int,
at::Tensor,
@@ -32,15 +37,26 @@ py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSp
bool,
std::shared_ptr<BC_GatedMLP>,
std::shared_ptr<BC_LinearFP16>,
float
float,
std::vector<std::shared_ptr<BC_LinearEXL3>>,
std::vector<std::shared_ptr<BC_LinearEXL3>>,
std::vector<std::shared_ptr<BC_LinearEXL3>>,
at::Tensor,
at::Tensor,
at::Tensor
>(),
py::arg("yh2"),
py::arg("yh"),
py::arg("interm_gu"),
py::arg("interm_g"),
py::arg("interm_u"),
py::arg("interm_a"),
py::arg("out_d"),
py::arg("out_d2"),
py::arg("out_d_sh"),
py::arg("z"),
py::arg("dq_temp_up"),
py::arg("dq_temp_down"),
py::arg("min_expert"),
py::arg("max_expert"),
py::arg("gate_ptrs_trellis"),
@@ -65,6 +81,14 @@ py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSp
py::arg("act_gelu"),
py::arg("shared_experts"),
py::arg("shared_gate"),
py::arg("act_limit")
py::arg("act_limit"),
py::arg("gates"),
py::arg("ups"),
py::arg("downs"),
py::arg("gu_trellis_ptr"),
py::arg("gu_suh_ptr"),
py::arg("gu_svh_ptr")
)
.def("run_bsz1", &BC_BlockSparseMLP::run_bsz1);
.def("run_bsz1", &BC_BlockSparseMLP::run_bsz1)
.def("run_single_expert", &BC_BlockSparseMLP::run_single_expert)
.def("run_single_expert_dq", &BC_BlockSparseMLP::run_single_expert_dq);


@@ -36,6 +36,52 @@ void had_ff_r_128_kernel
had_ff_r_128_inner(input_ptr, output_ptr, pre_scale, post_scale, r_scale);
}
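// Dual variants: rotate two tensors in a single launch, saving a kernel
// dispatch on the per-expert hot path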
__global__ __launch_bounds__(32)
void had_hf_r_128_dual_kernel
(
const half* __restrict__ input1_ptr,
half* __restrict__ output1_ptr,
const half* __restrict__ pre1_scale,
const half* __restrict__ post1_scale,
const half* __restrict__ input2_ptr,
half* __restrict__ output2_ptr,
const half* __restrict__ pre2_scale,
const half* __restrict__ post2_scale,
const float r_scale
)
{
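// blockIdx.x selects the row, blockIdx.y the 128-wide chunk within it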
input1_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128;
output1_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128;
had_hf_r_128_inner(input1_ptr, output1_ptr, pre1_scale, post1_scale, r_scale);
input2_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128;
output2_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128;
had_hf_r_128_inner(input2_ptr, output2_ptr, pre2_scale, post2_scale, r_scale);
}
__global__ __launch_bounds__(32)
void had_ff_r_128_dual_kernel
(
const float* __restrict__ input1_ptr,
float* __restrict__ output1_ptr,
const half* __restrict__ pre1_scale,
const half* __restrict__ post1_scale,
const float* __restrict__ input2_ptr,
float* __restrict__ output2_ptr,
const half* __restrict__ pre2_scale,
const half* __restrict__ post2_scale,
const float r_scale
)
{
input1_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128;
output1_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128;
had_ff_r_128_inner(input1_ptr, output1_ptr, pre1_scale, post1_scale, r_scale);
input2_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128;
output2_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128;
had_ff_r_128_inner(input2_ptr, output2_ptr, pre2_scale, post2_scale, r_scale);
}
/*
Compute y = (x.view(-1, 128) @ had_128).view(x.shape)
Works inplace if y == x
@@ -94,4 +140,73 @@ void had_r_128
}
else TORCH_CHECK(false, "unsupported datatype");
}
}
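/*
Dual-tensor variant of had_r_128: applies the same 128-element rotation to
(input1 -> output1) and (input2 -> output2) in one launch. All tensors must
share shape; works inplace if outputs alias inputs.
*/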
void had_r_128_dual
(
const at::Tensor& input1,
const at::Tensor& output1,
const c10::optional<at::Tensor>& pre_scale1,
const c10::optional<at::Tensor>& post_scale1,
const at::Tensor& input2,
const at::Tensor& output2,
const c10::optional<at::Tensor>& pre_scale2,
const c10::optional<at::Tensor>& post_scale2,
const float scale
)
{
const at::cuda::OptionalCUDAGuard device_guard(input1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_SHAPES_FULL(input1, output1);
TORCH_CHECK_SHAPES_FULL(input1, input2);
TORCH_CHECK_SHAPES_FULL(output1, output2);
TORCH_CHECK_DIM(input1, 2);
TORCH_CHECK_DIV(input1, 1, 128);
int rows = input1.size(0);
int cols = input1.size(1);
int blocks = cols / 128;
float r_scale = scale * 0.088388347648f; // scale / sqrt(128)
dim3 blockDim(32);
dim3 gridDim(rows, blocks);
if (input1.dtype() == at::kHalf)
{
TORCH_CHECK_DTYPE(output1, kHalf);
had_hf_r_128_dual_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const half*) input1.data_ptr(),
(half*) output1.data_ptr(),
(const half*) OPTPTR(pre_scale1),
(const half*) OPTPTR(post_scale1),
(const half*) input2.data_ptr(),
(half*) output2.data_ptr(),
(const half*) OPTPTR(pre_scale2),
(const half*) OPTPTR(post_scale2),
r_scale
);
cuda_check(cudaPeekAtLastError());
}
else if (input1.dtype() == at::kFloat)
{
TORCH_CHECK_DTYPE(output1, kFloat);
had_ff_r_128_dual_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const float*) input1.data_ptr(),
(float*) output1.data_ptr(),
(const half*) OPTPTR(pre_scale1),
(const half*) OPTPTR(post_scale1),
(const float*) input2.data_ptr(),
(float*) output2.data_ptr(),
(const half*) OPTPTR(pre_scale2),
(const half*) OPTPTR(post_scale2),
r_scale
);
cuda_check(cudaPeekAtLastError());
}
else TORCH_CHECK(false, "unsupported datatype");
}
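
For reference, the transform both entry points apply is the one documented above had_r_128 (per-element pre/post scales omitted here). A plain PyTorch sketch of the same math, including the scale / sqrt(128) normalization from the kernel, illustrative only:

import torch

def hadamard_128():
    # Sylvester construction: H_{2n} = [[H, H], [H, -H]], entries are +/- 1
    H = torch.ones(1, 1)
    while H.shape[0] < 128:
        H = torch.cat([torch.cat([H, H], dim = 1),
                       torch.cat([H, -H], dim = 1)], dim = 0)
    return H

def had_r_128_ref(x, scale = 1.0):
    # y = (x.view(-1, 128) @ had_128).view(x.shape), scaled by scale / sqrt(128)
    H = hadamard_128().to(x.dtype)
    return (x.view(-1, 128) @ H).view(x.shape) * (scale / 128 ** 0.5)

The dual entry point computes exactly this for both tensor pairs in one launch.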


@@ -10,3 +10,16 @@ void had_r_128
const c10::optional<at::Tensor>& post_scale,
const float scale
);
void had_r_128_dual
(
const at::Tensor& input1,
const at::Tensor& output1,
const c10::optional<at::Tensor>& pre_scale1,
const c10::optional<at::Tensor>& post_scale1,
const at::Tensor& input2,
const at::Tensor& output2,
const c10::optional<at::Tensor>& pre_scale2,
const c10::optional<at::Tensor>& post_scale2,
const float scale
);


@@ -11,7 +11,9 @@ from dataclasses import dataclass
from .mlp import MLP, GatedMLP
from ..model.model_tp_alloc import TPAllocation
from ..util import profile_opt
from ..util.tensor import g_tensor_cache
TEMP_ROWS = 32
@dataclass
class RoutingCFG:
@@ -141,6 +143,7 @@ class ExpertsCFG:
interm_u: torch.Tensor
interm_a: torch.Tensor
out_d: torch.Tensor
out_d2: torch.Tensor
min_expert: int
max_expert: int
@@ -208,6 +211,9 @@ class BlockSparseMLP(Module):
self.n_group = n_group
self.topk_group = topk_group
assert num_experts_per_tok <= TEMP_ROWS, \
f"Too many experts per token, max supported is {TEMP_ROWS}"
if routing_gate is None and key_routing_gate is None:
self.routing_gate = None
elif routing_gate is None:
@@ -377,66 +383,86 @@ class BlockSparseMLP(Module):
self.multi_up = MultiLinear(self.device, self.ups)
self.multi_down = MultiLinear(self.device, self.downs)
yh = torch.empty(
(self.num_experts_per_tok, 1, self.hidden_size),
dtype = torch.half,
device = self.device
)
interm_g = torch.empty(
(self.num_experts_per_tok, 1, self.intermediate_size),
dtype = self.interm_dtype,
device = self.device
)
interm_u = torch.empty_like(interm_g)
interm_a = torch.empty_like(interm_u, dtype = torch.half) if self.interm_dtype != torch.half else interm_u
out_d = torch.empty(
(self.num_experts_per_tok, 1, self.hidden_size),
dtype = self.out_dtype or torch.half,
device = self.device
)
# Temp buffers
numex = self.num_experts_per_tok
H = self.hidden_size
I = self.intermediate_size
device = self.device
temp_hidden = g_tensor_cache.get(device, (TEMP_ROWS * 2, H), torch.half, "temp_hidden")
temp_interm = g_tensor_cache.get(device, (TEMP_ROWS * 2, I), self.interm_dtype, "temp_interm")
temp_activa = g_tensor_cache.get(device, (TEMP_ROWS, I), torch.half, "temp_activa")
temp_output = g_tensor_cache.get(device, (TEMP_ROWS, H), self.out_dtype or torch.half, "temp_output")
yh = temp_hidden[:numex].view(numex, 1, H)
interm_g = temp_interm[:numex].view(numex, 1, I)
interm_u = temp_interm[numex:numex*2].view(numex, 1, I)
interm_a = temp_activa[:numex].view(numex, 1, I)
yh2 = temp_hidden
interm_gu = temp_interm
out_d = temp_output[:numex].view(numex, 1, H)
out_d2 = temp_output
# Expert interval for split module; (-1, -1) indicates no split
mine, maxe = self.routing_first, self.routing_last
if mine is None or maxe - mine == self.num_experts:
mine, maxe = -1, -1
self.experts_cfg = ExpertsCFG(
cfg = ExpertsCFG(
yh = yh,
interm_g = interm_g,
interm_u = interm_u,
interm_a = interm_a,
out_d = out_d,
out_d2 = out_d2,
min_expert = mine,
max_expert = maxe,
)
self.experts_cfg = cfg
cfg = self.experts_cfg
if self.is_quantized:
sh_exp = None
# Embed bound classes for shared experts and shared gate
sh_exp_bc = None
sh_exp_t = None
sh_gate = None
sh_gate_bc = None
sh_gate_t = None
self.bc_sh_exp = False
if self.shared_experts and isinstance(self.shared_experts, GatedMLP) and self.shared_experts.bc is not None:
self.bc_sh_exp = True
sh_exp = self.shared_experts.bc
sh_exp_t = torch.empty(
(1, 1, self.hidden_size),
dtype = self.out_dtype or torch.half,
device = self.device
)
sh_exp_bc = self.shared_experts.bc
sh_exp_t = torch.empty((1, 1, H), dtype = self.out_dtype or torch.half, device = self.device)
if self.shared_gate:
assert self.shared_gate.quant_type == "fp16"
sh_gate = self.shared_gate.inner.bc
sh_gate_bc = self.shared_gate.inner.bc
sh_gate_t = torch.empty((1, 1, 1), dtype = self.shared_gate.out_dtype, device = self.device)
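# Per-expert device pointer tables; gate/up pointers are stacked pairwise so the
# fused gate+up mgemm can address both projections with a single expert index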
g_trellis_ptr = torch.tensor([l.inner.trellis.data_ptr() for l in self.gates])
u_trellis_ptr = torch.tensor([l.inner.trellis.data_ptr() for l in self.ups])
g_suh_ptr = torch.tensor([l.inner.suh.data_ptr() for l in self.gates])
u_suh_ptr = torch.tensor([l.inner.suh.data_ptr() for l in self.ups])
g_svh_ptr = torch.tensor([l.inner.svh.data_ptr() for l in self.gates])
u_svh_ptr = torch.tensor([l.inner.svh.data_ptr() for l in self.ups])
gu_trellis_ptr = torch.stack((g_trellis_ptr, u_trellis_ptr), dim = 0).T.contiguous().to(self.device)
gu_suh_ptr = torch.stack((g_suh_ptr, u_suh_ptr), dim = 0).T.contiguous().to(self.device)
gu_svh_ptr = torch.stack((g_svh_ptr, u_svh_ptr), dim = 0).T.contiguous().to(self.device)
dq_temp_up = g_tensor_cache.get(device, (H, I), torch.half, "dq_temp")
dq_temp_down = dq_temp_up.view(I, H)
self.bc = ext.BC_BlockSparseMLP(
yh2,
cfg.yh,
interm_gu,
cfg.interm_g,
cfg.interm_u,
cfg.interm_a,
cfg.out_d,
cfg.out_d2,
sh_exp_t,
sh_gate_t,
dq_temp_up,
dq_temp_down,
cfg.min_expert,
cfg.max_expert,
self.multi_gate.ptrs_trellis,
@@ -459,9 +485,15 @@ class BlockSparseMLP(Module):
self.multi_down.mul1,
self.activation_fn == "silu",
self.activation_fn == "gelu",
sh_exp,
sh_gate,
self.act_limit
sh_exp_bc,
sh_gate_bc,
self.act_limit,
[x.inner.bc for x in self.gates],
[x.inner.bc for x in self.ups],
[x.inner.bc for x in self.downs],
gu_trellis_ptr,
gu_suh_ptr,
gu_svh_ptr
)
@@ -545,45 +577,101 @@ class BlockSparseMLP(Module):
if self.intermediate_size == 0 or self.num_local_experts == 0:
final_hidden_states = torch.zeros_like(x, dtype = self.out_dtype)
# Torch path
# Torch/C++ path
elif bsz >= self.f_threshold or not self.is_quantized:
final_hidden_states = torch.zeros_like(y, dtype = self.out_dtype)
if self.routing_device is None or self.num_local_experts == self.num_experts:
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes = self.num_local_experts)
else:
# TODO: profile, maybe optimize
selected_experts -= self.routing_first
invalid = (selected_experts < 0) | (selected_experts >= self.num_local_experts)
shifted = torch.where(invalid, torch.zeros_like(selected_experts), selected_experts + 1)
expert_mask = F.one_hot(shifted, num_classes = self.num_local_experts + 1)[..., 1:]
# routing_weights[invalid] = 0.0
# if self.routing_device is None or self.num_local_experts == self.num_experts:
# expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes = self.num_local_experts)
# else:
# selected_experts -= self.routing_first
# invalid = (selected_experts < 0) | (selected_experts >= self.num_local_experts)
# shifted = torch.where(invalid, torch.zeros_like(selected_experts), selected_experts + 1)
# expert_mask = F.one_hot(shifted, num_classes = self.num_local_experts + 1)[..., 1:]
if self.num_local_experts is None or self.num_local_experts > 0:
num_ex = self.num_local_experts or self.num_experts
expert_count = expert_mask.view(-1, num_ex).sum(dim = 0).cpu()
expert_mask = expert_mask.permute(2, 1, 0)
def mlp(exp_i, xc):
g = self.gates[exp_i].forward(xc, params)
u = self.ups[exp_i].forward(xc, params)
a = u if self.interm_dtype == torch.half else torch.empty_like(u, dtype = torch.half)
self.activation_fn_call(g, u, a, self.act_limit)
return self.downs[exp_i].forward(a, params)
num_tokens, top_k = selected_experts.shape
E = self.num_local_experts
# Flatten assignments
flat_expert_global = selected_experts.reshape(-1) # [num_tokens * top_k]
flat_weight = routing_weights.reshape(-1) # [num_tokens * top_k]
# Token indices corresponding to each flattened assignment
flat_token = torch.arange(num_tokens, device = y.device)
flat_token = flat_token.repeat_interleave(top_k) # [num_tokens * top_k]
if self.routing_device is None or self.num_local_experts == self.num_experts:
flat_expert_local = flat_expert_global
else:
flat_expert_local = flat_expert_global - self.routing_first
valid = (flat_expert_local >= 0) & (flat_expert_local < E)
flat_expert_local = torch.where(valid, flat_expert_local, torch.full_like(flat_expert_local, E))
# Group once by local expert id (including sentinel for expert-parallel mode)
order = flat_expert_local.argsort()
local_sorted = flat_expert_local[order]
token_sorted = flat_token[order]
weight_sorted = flat_weight[order]
# Count how many assignments per expert
expert_count = torch.bincount(local_sorted, minlength = E + 1)
expert_ptr = torch.empty(E + 2, device = y.device, dtype = torch.long)
expert_ptr[0] = 0
expert_ptr[1:] = expert_count.cumsum(0)
expert_ptr = expert_ptr.tolist()
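# e.g. E = 3, local_sorted = [0, 0, 1, 3] (3 = sentinel): expert_count = [2, 1, 0, 1],
# expert_ptr = [0, 2, 3, 3, 4]; iterating experts 0..2 below skips the sentinel bucket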
for expert_idx in range(num_ex):
if expert_count[expert_idx] == 0:
start = expert_ptr[expert_idx]
end = expert_ptr[expert_idx + 1]
count = end - start
if count == 0:
continue
idx, top_x = torch.where(expert_mask[expert_idx])
current_state = y[None, top_x].reshape(-1, self.hidden_size)
current_state = mlp(expert_idx, current_state) * routing_weights[top_x, idx, None]
top_x = token_sorted[start:end]
w = weight_sorted[start:end].unsqueeze(1)
current_state = y.index_select(0, top_x)
if self.bc is not None:
if count <= TEMP_ROWS:
self.bc.run_single_expert(current_state, expert_idx)
current_state = self.experts_cfg.out_d2[:count]
else:
out_state = torch.empty(
(count, self.hidden_size),
dtype = self.out_dtype or torch.half,
device = self.device
)
interm = torch.empty(
(2, count, self.intermediate_size),
dtype = self.interm_dtype,
device = self.device
)
interm_a = interm[0] if self.interm_dtype == torch.half else torch.empty_like(interm[0], dtype = torch.half)
yh = torch.empty((2, count, self.hidden_size), dtype = torch.half, device = self.device)
self.bc.run_single_expert_dq(current_state, expert_idx, yh, interm, interm_a, out_state)
current_state = out_state
else:
def mlp(exp_i, xc):
g = self.gates[exp_i].forward(xc, params)
u = self.ups[exp_i].forward(xc, params)
a = u if self.interm_dtype == torch.half else torch.empty_like(u, dtype = torch.half)
self.activation_fn_call(g, u, a, self.act_limit)
return self.downs[exp_i].forward(a, params)
current_state = mlp(expert_idx, current_state)
current_state.mul_(w)
final_hidden_states.index_add_(0, top_x, current_state)
final_hidden_states = final_hidden_states.reshape(x.shape)
final_hidden_states = to2(final_hidden_states, out_dtype, self.out_dtype)
# Fused path
# Fused path, few tokens
elif bsz > 1:
final_hidden_states = torch.empty_like(y, dtype = self.out_dtype)
@@ -592,12 +680,6 @@ class BlockSparseMLP(Module):
selected_experts = selected_experts.unsqueeze(1)
routing_weights = routing_weights.unsqueeze(1)
# yh = torch.empty((bsz, self.num_experts_per_tok, 1, self.hidden_size), dtype = torch.half, device = self.device)
# interm_g = torch.empty((bsz, self.num_experts_per_tok, 1, self.intermediate_size), dtype = self.interm_dtype, device = self.device)
# interm_u = torch.empty_like(interm_g)
# interm_a = torch.empty_like(interm_u, dtype = torch.half) if self.interm_dtype != torch.half else interm_u
# out_d = torch.empty((bsz, self.num_experts_per_tok, 1, self.hidden_size), dtype = self.out_dtype or torch.half, device = self.device)
cfg = self.experts_cfg
mine, maxe = self.routing_first, self.routing_last