From 766a28dc606670c3d9701394543cb8019d98fec2 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 7 Mar 2026 01:21:34 +0100 Subject: [PATCH] BlockSparseMLP: Improved batch routing --- .../libtorch/blocksparse_mlp.cpp | 258 +++++++++++++++++- .../exllamav3_ext/libtorch/blocksparse_mlp.h | 104 ++++--- .../libtorch/blocksparse_mlp_bc.h | 30 +- exllamav3/exllamav3_ext/quant/hadamard.cu | 117 +++++++- exllamav3/exllamav3_ext/quant/hadamard.cuh | 13 + exllamav3/modules/block_sparse_mlp.py | 202 ++++++++++---- 6 files changed, 615 insertions(+), 109 deletions(-) diff --git a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp index d150a97..90c7460 100644 --- a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp +++ b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp @@ -6,6 +6,8 @@ #include "../util.h" #include "../hgemm.cuh" #include "../quant/exl3_gemm.cuh" +#include "../quant/hadamard.cuh" +#include "../quant/reconstruct.cuh" #include "../activation.cuh" #include "../add.cuh" @@ -69,7 +71,7 @@ void BC_BlockSparseMLP::run_bsz1_gr Graph* graph ) { - py::gil_scoped_release _; + //py::gil_scoped_release _; const at::Tensor& yi = y.unsqueeze(0); exl3_mgemm_gr @@ -213,4 +215,258 @@ void BC_BlockSparseMLP::run_bsz1 graph_bsz1.launch(args, stream); #endif + #undef USE_GRAPH +} + +BC_BlockSparseMLP::BC_BlockSparseMLP +( + at::Tensor _yh2, + at::Tensor _yh, + at::Tensor _interm_gu, + at::Tensor _interm_g, + at::Tensor _interm_u, + at::Tensor _interm_a, + at::Tensor _out_d, + at::Tensor _out_d2, + c10::optional _out_d_sh, + c10::optional _z, + at::Tensor _dq_temp_up, + at::Tensor _dq_temp_down, + int _min_expert, + int _max_expert, + at::Tensor _gate_ptrs_trellis, + at::Tensor _gate_ptrs_suh, + at::Tensor _gate_ptrs_svh, + int _gate_K, + bool _gate_mcg, + bool _gate_mul1, + at::Tensor _up_ptrs_trellis, + at::Tensor _up_ptrs_suh, + at::Tensor _up_ptrs_svh, + int _up_K, + bool 
_up_mcg, + bool _up_mul1, + at::Tensor _down_ptrs_trellis, + at::Tensor _down_ptrs_suh, + at::Tensor _down_ptrs_svh, + int _down_K, + bool _down_mcg, + bool _down_mul1, + bool _act_silu, + bool _act_gelu, + std::shared_ptr _shared_experts, + std::shared_ptr _shared_gate, + float _act_limit, + std::vector> _gates, + std::vector> _ups, + std::vector> _downs, + at::Tensor _gu_trellis_ptr, + at::Tensor _gu_suh_ptr, + at::Tensor _gu_svh_ptr +) : + yh2 (std::move(_yh2)), + yh (std::move(_yh)), + interm_gu (std::move(_interm_gu)), + interm_g (std::move(_interm_g)), + interm_u (std::move(_interm_u)), + interm_a (std::move(_interm_a)), + out_d (std::move(_out_d)), + out_d2 (std::move(_out_d2)), + out_d_sh (std::move(_out_d_sh)), + z (std::move(_z)), + dq_temp_up (std::move(_dq_temp_up)), + dq_temp_down (std::move(_dq_temp_down)), + min_expert (_min_expert), + max_expert (_max_expert), + gate_ptrs_trellis (std::move(_gate_ptrs_trellis)), + gate_ptrs_suh (std::move(_gate_ptrs_suh)), + gate_ptrs_svh (std::move(_gate_ptrs_svh)), + gate_K (_gate_K), + gate_mcg (_gate_mcg), + gate_mul1 (_gate_mul1), + up_ptrs_trellis (std::move(_up_ptrs_trellis)), + up_ptrs_suh (std::move(_up_ptrs_suh)), + up_ptrs_svh (std::move(_up_ptrs_svh)), + up_K (_up_K), + up_mcg (_up_mcg), + up_mul1 (_up_mul1), + down_ptrs_trellis (std::move(_down_ptrs_trellis)), + down_ptrs_suh (std::move(_down_ptrs_suh)), + down_ptrs_svh (std::move(_down_ptrs_svh)), + down_K (_down_K), + down_mcg (_down_mcg), + down_mul1 (_down_mul1), + act_silu (_act_silu), + act_gelu (_act_gelu), + shared_experts (_shared_experts), + shared_gate (_shared_gate), + act_limit (_act_limit), + gates (_gates), + ups (_ups), + downs (_downs), + gu_trellis_ptr (_gu_trellis_ptr), + gu_suh_ptr (_gu_suh_ptr), + gu_svh_ptr (_gu_svh_ptr) +{ + gate_ptrs_trellis_cpu = gate_ptrs_trellis.cpu(); + gate_ptrs_suh_cpu = gate_ptrs_suh.cpu(); + gate_ptrs_svh_cpu = gate_ptrs_svh.cpu(); + up_ptrs_trellis_cpu = up_ptrs_trellis.cpu(); + up_ptrs_suh_cpu = 
up_ptrs_suh.cpu(); + up_ptrs_svh_cpu = up_ptrs_svh.cpu(); + down_ptrs_trellis_cpu = down_ptrs_trellis.cpu(); + down_ptrs_suh_cpu = down_ptrs_suh.cpu(); + down_ptrs_svh_cpu = down_ptrs_svh.cpu(); + + max_experts_per_token = interm_g.size(0); + max_tokens_per_expert = max_experts_per_token; + + for (int i = 0; i < max_tokens_per_expert; ++i) + { + interm_g_single.push_back(interm_g.squeeze(1).slice(0, 0, i + 1)); + interm_u_single.push_back(interm_u.squeeze(1).slice(0, 0, i + 1)); + interm_a_single.push_back(interm_a.squeeze(1).slice(0, 0, i + 1)); + out_d_single.push_back(out_d.squeeze(1).slice(0, 0, i + 1)); + } + + TORCH_CHECK(max_expert <= MAX_EXPERTS, "BC_BlockSparseMLP: Too many experts"); + + use_mgemm = gate_K == up_K; +} + +void BC_BlockSparseMLP::run_single_expert +( + const at::Tensor& y, + const int expert_idx +) +{ + int bsz = y.size(0); + + at::Tensor ai = interm_a.slice(0, 0, bsz); + at::Tensor oi = out_d2.slice(0, 0, bsz); + + if (use_mgemm) + { + at::Tensor yb = y.unsqueeze(0); + at::Tensor gui = interm_gu.slice(0, 0, bsz * 2).view({2, bsz, interm_gu.size(1)}); + at::Tensor yh2i = yh2.slice(0, 0, 2).view({2, 1, yh2.size(1)}); + + exl3_mgemm + ( + yb, + gu_trellis_ptr[expert_idx], + gui, + gu_suh_ptr[expert_idx], + yh2, + gu_svh_ptr[expert_idx], + c10::nullopt, + c10::nullopt, + gate_K, + -1, + gate_mcg, + gate_mul1, + -1, + -1, + 0 + ); + + at::Tensor gi = gui[0]; + at::Tensor ui = gui[1]; + + if (act_silu) + silu_mul(gi, ui, ai, act_limit); + else if (act_gelu) + gelu_mul(gi, ui, ai, act_limit); + } + else + { + at::Tensor gi = interm_gu.slice(0, 0, bsz); + at::Tensor ui = interm_gu.slice(0, bsz, bsz * 2); + + exl3_gemm + ( + y, + gates[expert_idx]->trellis, + gi, + gates[expert_idx]->suh, + yh, + gates[expert_idx]->svh, + -1, + gate_mcg, + gate_mul1, + 0 + ); + + exl3_gemm + ( + y, + ups[expert_idx]->trellis, + ui, + ups[expert_idx]->suh, + yh, + ups[expert_idx]->svh, + -1, + up_mcg, + up_mul1, + 0 + ); + + if (act_silu) + silu_mul(gi, ui, ai, 
act_limit); + else if (act_gelu) + gelu_mul(gi, ui, ai, act_limit); + } + + exl3_gemm + ( + ai, + downs[expert_idx]->trellis, + oi, + downs[expert_idx]->suh, + ai, + downs[expert_idx]->svh, + -1, + down_mcg, + down_mul1, + 0 + ); +} + +void BC_BlockSparseMLP::run_single_expert_dq +( + const at::Tensor& y, + const int expert_idx, + at::Tensor& yh, + at::Tensor& interm, + at::Tensor& interm_a, + at::Tensor& out +) +{ + int bsz = y.size(0); + + at::Tensor yh1 = yh[0]; + at::Tensor yh2 = yh[1]; + at::Tensor interm1 = interm[0]; + at::Tensor interm2 = interm[1]; + + had_r_128_dual(y, yh1, gates[expert_idx]->suh, c10::nullopt, + y, yh2, ups[expert_idx]->suh, c10::nullopt, 1.0); + + reconstruct(dq_temp_up, gates[expert_idx]->trellis, gate_K, gate_mcg, gate_mul1); + hgemm(yh1, dq_temp_up, interm1); + reconstruct(dq_temp_up, ups[expert_idx]->trellis, up_K, up_mcg, up_mul1); + hgemm(yh2, dq_temp_up, interm2); + + had_r_128_dual(interm1, interm1, c10::nullopt, gates[expert_idx]->svh, + interm2, interm2, c10::nullopt, ups[expert_idx]->svh, 1.0); + + if (act_silu) + silu_mul(interm1, interm2, interm_a, act_limit); + else if (act_gelu) + gelu_mul(interm1, interm2, interm_a, act_limit); + + had_r_128(interm_a, interm_a, downs[expert_idx]->suh, c10::nullopt, 1.0); + reconstruct(dq_temp_down, downs[expert_idx]->trellis, down_K, down_mcg, down_mul1); + hgemm(interm_a, dq_temp_down, out); + had_r_128(out, out, c10::nullopt, downs[expert_idx]->svh, 1.0); } diff --git a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.h b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.h index 0a02049..29717ad 100644 --- a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.h +++ b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.h @@ -9,6 +9,8 @@ namespace py = pybind11; #include "linear.h" #include "../graph.cuh" +#define MAX_EXPERTS 512 + std::tuple blocksparse_mlp_routing( int bsz, const py::object& cfg, @@ -18,30 +20,35 @@ std::tuple blocksparse_mlp_routing( struct BC_BlockSparseMLP { + at::Tensor 
yh2; at::Tensor yh; + at::Tensor interm_gu; at::Tensor interm_g; at::Tensor interm_u; at::Tensor interm_a; at::Tensor out_d; + at::Tensor out_d2; c10::optional out_d_sh; c10::optional z; + at::Tensor dq_temp_up; + at::Tensor dq_temp_down; int min_expert; int max_expert; - at::Tensor gate_ptrs_trellis; - at::Tensor gate_ptrs_suh; - at::Tensor gate_ptrs_svh; + at::Tensor gate_ptrs_trellis; at::Tensor gate_ptrs_trellis_cpu; + at::Tensor gate_ptrs_suh; at::Tensor gate_ptrs_suh_cpu; + at::Tensor gate_ptrs_svh; at::Tensor gate_ptrs_svh_cpu; int gate_K; bool gate_mcg; bool gate_mul1; - at::Tensor up_ptrs_trellis; - at::Tensor up_ptrs_suh; - at::Tensor up_ptrs_svh; + at::Tensor up_ptrs_trellis; at::Tensor up_ptrs_trellis_cpu; + at::Tensor up_ptrs_suh; at::Tensor up_ptrs_suh_cpu; + at::Tensor up_ptrs_svh; at::Tensor up_ptrs_svh_cpu; int up_K; bool up_mcg; bool up_mul1; - at::Tensor down_ptrs_trellis; - at::Tensor down_ptrs_suh; - at::Tensor down_ptrs_svh; + at::Tensor down_ptrs_trellis; at::Tensor down_ptrs_trellis_cpu; + at::Tensor down_ptrs_suh; at::Tensor down_ptrs_suh_cpu; + at::Tensor down_ptrs_svh; at::Tensor down_ptrs_svh_cpu; int down_K; bool down_mcg; bool down_mul1; @@ -50,18 +57,38 @@ struct BC_BlockSparseMLP std::shared_ptr shared_experts; std::shared_ptr shared_gate; float act_limit; + std::vector> gates; + std::vector> ups; + std::vector> downs; + at::Tensor gu_trellis_ptr; + at::Tensor gu_suh_ptr; + at::Tensor gu_svh_ptr; + + int max_experts_per_token; + int max_tokens_per_expert; + std::vector interm_g_single; + std::vector interm_u_single; + std::vector interm_a_single; + std::vector out_d_single; + + bool use_mgemm; Graph graph_bsz1; BC_BlockSparseMLP ( + at::Tensor _yh2, at::Tensor _yh, + at::Tensor _interm_gu, at::Tensor _interm_g, at::Tensor _interm_u, at::Tensor _interm_a, at::Tensor _out_d, + at::Tensor _out_d2, c10::optional _out_d_sh, c10::optional _z, + at::Tensor _dq_temp_up, + at::Tensor _dq_temp_down, int _min_expert, int _max_expert, at::Tensor 
_gate_ptrs_trellis, @@ -86,41 +113,14 @@ struct BC_BlockSparseMLP bool _act_gelu, std::shared_ptr _shared_experts, std::shared_ptr _shared_gate, - float _act_limit - ) : - yh (std::move(_yh)), - interm_g (std::move(_interm_g)), - interm_u (std::move(_interm_u)), - interm_a (std::move(_interm_a)), - out_d (std::move(_out_d)), - out_d_sh (std::move(_out_d_sh)), - z (std::move(_z)), - min_expert (_min_expert), - max_expert (_max_expert), - gate_ptrs_trellis (std::move(_gate_ptrs_trellis)), - gate_ptrs_suh (std::move(_gate_ptrs_suh)), - gate_ptrs_svh (std::move(_gate_ptrs_svh)), - gate_K (_gate_K), - gate_mcg (_gate_mcg), - gate_mul1 (_gate_mul1), - up_ptrs_trellis (std::move(_up_ptrs_trellis)), - up_ptrs_suh (std::move(_up_ptrs_suh)), - up_ptrs_svh (std::move(_up_ptrs_svh)), - up_K (_up_K), - up_mcg (_up_mcg), - up_mul1 (_up_mul1), - down_ptrs_trellis (std::move(_down_ptrs_trellis)), - down_ptrs_suh (std::move(_down_ptrs_suh)), - down_ptrs_svh (std::move(_down_ptrs_svh)), - down_K (_down_K), - down_mcg (_down_mcg), - down_mul1 (_down_mul1), - act_silu (_act_silu), - act_gelu (_act_gelu), - shared_experts (_shared_experts), - shared_gate (_shared_gate), - act_limit (_act_limit) - {} + float _act_limit, + std::vector> _gates, + std::vector> _ups, + std::vector> _downs, + at::Tensor _gu_trellis_ptr, + at::Tensor _gu_suh_ptr, + at::Tensor _gu_svh_ptr + ); void run_bsz1_gr ( @@ -136,4 +136,20 @@ struct BC_BlockSparseMLP at::Tensor& selected_experts, at::Tensor& routing_weights ); + + void run_single_expert + ( + const at::Tensor& y, + const int expert_idx + ); + + void run_single_expert_dq + ( + const at::Tensor& y, + const int expert_idx, + at::Tensor& yh, + at::Tensor& interm, + at::Tensor& interm_a, + at::Tensor& out + ); }; \ No newline at end of file diff --git a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp_bc.h b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp_bc.h index 3ee9738..e2a0363 100644 --- a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp_bc.h +++ 
b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp_bc.h @@ -6,8 +6,13 @@ py::class_>(m, "BC_BlockSp at::Tensor, at::Tensor, at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, c10::optional, c10::optional, + at::Tensor, + at::Tensor, int, int, at::Tensor, @@ -32,15 +37,26 @@ py::class_>(m, "BC_BlockSp bool, std::shared_ptr, std::shared_ptr, - float + float, + std::vector>, + std::vector>, + std::vector>, + at::Tensor, + at::Tensor, + at::Tensor >(), + py::arg("yh2"), py::arg("yh"), + py::arg("interm_gu"), py::arg("interm_g"), py::arg("interm_u"), py::arg("interm_a"), py::arg("out_d"), + py::arg("out_d2"), py::arg("out_d_sh"), py::arg("z"), + py::arg("dq_temp_up"), + py::arg("dq_temp_down"), py::arg("min_expert"), py::arg("max_expert"), py::arg("gate_ptrs_trellis"), @@ -65,6 +81,14 @@ py::class_>(m, "BC_BlockSp py::arg("act_gelu"), py::arg("shared_experts"), py::arg("shared_gate"), - py::arg("act_limit") + py::arg("act_limit"), + py::arg("gates"), + py::arg("ups"), + py::arg("downs"), + py::arg("gu_trellis_ptr"), + py::arg("gu_suh_ptr"), + py::arg("gu_svh_ptr") ) -.def("run_bsz1", &BC_BlockSparseMLP::run_bsz1); +.def("run_bsz1", &BC_BlockSparseMLP::run_bsz1) +.def("run_single_expert", &BC_BlockSparseMLP::run_single_expert) +.def("run_single_expert_dq", &BC_BlockSparseMLP::run_single_expert_dq); diff --git a/exllamav3/exllamav3_ext/quant/hadamard.cu b/exllamav3/exllamav3_ext/quant/hadamard.cu index 277b18a..fd65955 100644 --- a/exllamav3/exllamav3_ext/quant/hadamard.cu +++ b/exllamav3/exllamav3_ext/quant/hadamard.cu @@ -36,6 +36,52 @@ void had_ff_r_128_kernel had_ff_r_128_inner(input_ptr, output_ptr, pre_scale, post_scale, r_scale); } +__global__ __launch_bounds__(32) +void had_hf_r_128_dual_kernel +( + const half* __restrict__ input1_ptr, + half* __restrict__ output1_ptr, + const half* __restrict__ pre1_scale, + const half* __restrict__ post1_scale, + const half* __restrict__ input2_ptr, + half* __restrict__ output2_ptr, + const half* __restrict__ pre2_scale, + 
const half* __restrict__ post2_scale, + const float r_scale +) +{ + input1_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + output1_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + had_hf_r_128_inner(input1_ptr, output1_ptr, pre1_scale, post1_scale, r_scale); + + input2_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + output2_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + had_hf_r_128_inner(input2_ptr, output2_ptr, pre2_scale, post2_scale, r_scale); +} + +__global__ __launch_bounds__(32) +void had_ff_r_128_dual_kernel +( + const float* __restrict__ input1_ptr, + float* __restrict__ output1_ptr, + const half* __restrict__ pre1_scale, + const half* __restrict__ post1_scale, + const float* __restrict__ input2_ptr, + float* __restrict__ output2_ptr, + const half* __restrict__ pre2_scale, + const half* __restrict__ post2_scale, + const float r_scale +) +{ + input1_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + output1_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + had_ff_r_128_inner(input1_ptr, output1_ptr, pre1_scale, post1_scale, r_scale); + + input2_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + output2_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + had_ff_r_128_inner(input2_ptr, output2_ptr, pre2_scale, post2_scale, r_scale); +} + /* Compute y = (x.view(-1, 128) @ had_128).view(x.shape) Works inplace if y == x @@ -94,4 +140,73 @@ void had_r_128 } else TORCH_CHECK(false, "unsupported datatype"); -} \ No newline at end of file +} + +void had_r_128_dual +( + const at::Tensor& input1, + const at::Tensor& output1, + const c10::optional& pre_scale1, + const c10::optional& post_scale1, + const at::Tensor& input2, + const at::Tensor& output2, + const c10::optional& pre_scale2, + const c10::optional& post_scale2, + const float scale +) +{ + const at::cuda::OptionalCUDAGuard device_guard(input1.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + 
TORCH_CHECK_SHAPES_FULL(input1, output1); + TORCH_CHECK_SHAPES_FULL(input1, input2); + TORCH_CHECK_SHAPES_FULL(output1, output2); + TORCH_CHECK_DIM(input1, 2); + TORCH_CHECK_DIV(input1, 1, 128); + int rows = input1.size(0); + int cols = input1.size(1); + + int blocks = cols / 128; + float r_scale = scale * 0.088388347648f; // scale / sqrt(128) + + dim3 blockDim(32); + dim3 gridDim(rows, blocks); + + if (input1.dtype() == at::kHalf) + { + TORCH_CHECK_DTYPE(output1, kHalf); + had_hf_r_128_dual_kernel<<>> + ( + (const half*) input1.data_ptr(), + (half*) output1.data_ptr(), + (const half*) OPTPTR(pre_scale1), + (const half*) OPTPTR(post_scale1), + (const half*) input2.data_ptr(), + (half*) output2.data_ptr(), + (const half*) OPTPTR(pre_scale2), + (const half*) OPTPTR(post_scale2), + r_scale + ); + cuda_check(cudaPeekAtLastError()); + } + + else if (input1.dtype() == at::kFloat) + { + TORCH_CHECK_DTYPE(output1, kFloat); + had_ff_r_128_dual_kernel<<>> + ( + (const float*) input1.data_ptr(), + (float*) output1.data_ptr(), + (const half*) OPTPTR(pre_scale1), + (const half*) OPTPTR(post_scale1), + (const float*) input2.data_ptr(), + (float*) output2.data_ptr(), + (const half*) OPTPTR(pre_scale2), + (const half*) OPTPTR(post_scale2), + r_scale + ); + cuda_check(cudaPeekAtLastError()); + } + + else TORCH_CHECK(false, "unsupported datatype"); +} diff --git a/exllamav3/exllamav3_ext/quant/hadamard.cuh b/exllamav3/exllamav3_ext/quant/hadamard.cuh index fff4737..5c25ff3 100644 --- a/exllamav3/exllamav3_ext/quant/hadamard.cuh +++ b/exllamav3/exllamav3_ext/quant/hadamard.cuh @@ -10,3 +10,16 @@ void had_r_128 const c10::optional& post_scale, const float scale ); + +void had_r_128_dual +( + const at::Tensor& input1, + const at::Tensor& output1, + const c10::optional& pre_scale1, + const c10::optional& post_scale1, + const at::Tensor& input2, + const at::Tensor& output2, + const c10::optional& pre_scale2, + const c10::optional& post_scale2, + const float scale +); diff --git 
a/exllamav3/modules/block_sparse_mlp.py b/exllamav3/modules/block_sparse_mlp.py index f4fc40b..d6cdabc 100644 --- a/exllamav3/modules/block_sparse_mlp.py +++ b/exllamav3/modules/block_sparse_mlp.py @@ -11,7 +11,9 @@ from dataclasses import dataclass from .mlp import MLP, GatedMLP from ..model.model_tp_alloc import TPAllocation from ..util import profile_opt +from ..util.tensor import g_tensor_cache +TEMP_ROWS = 32 @dataclass class RoutingCFG: @@ -141,6 +143,7 @@ class ExpertsCFG: interm_u: torch.Tensor interm_a: torch.Tensor out_d: torch.Tensor + out_d2: torch.Tensor min_expert: int max_expert: int @@ -208,6 +211,9 @@ class BlockSparseMLP(Module): self.n_group = n_group self.topk_group = topk_group + assert num_experts_per_tok <= TEMP_ROWS, \ + f"Too many experts per token, max supported is {TEMP_ROWS}" + if routing_gate is None and key_routing_gate is None: self.routing_gate = None elif routing_gate is None: @@ -377,66 +383,86 @@ class BlockSparseMLP(Module): self.multi_up = MultiLinear(self.device, self.ups) self.multi_down = MultiLinear(self.device, self.downs) - yh = torch.empty( - (self.num_experts_per_tok, 1, self.hidden_size), - dtype = torch.half, - device = self.device - ) - interm_g = torch.empty( - (self.num_experts_per_tok, 1, self.intermediate_size), - dtype = self.interm_dtype, - device = self.device - ) - interm_u = torch.empty_like(interm_g) - interm_a = torch.empty_like(interm_u, dtype = torch.half) if self.interm_dtype != torch.half else interm_u - out_d = torch.empty( - (self.num_experts_per_tok, 1, self.hidden_size), - dtype = self.out_dtype or torch.half, - device = self.device - ) + # Temp buffers + numex = self.num_experts_per_tok + H = self.hidden_size + I = self.intermediate_size + device = self.device + temp_hidden = g_tensor_cache.get(device, (TEMP_ROWS * 2, H), torch.half, "temp_hidden") + temp_interm = g_tensor_cache.get(device, (TEMP_ROWS * 2, I), self.interm_dtype, "temp_interm") + temp_activa = g_tensor_cache.get(device, (TEMP_ROWS, 
I), torch.half, "temp_activa") + temp_output = g_tensor_cache.get(device, (TEMP_ROWS, H), self.out_dtype or torch.half, "temp_output") + + yh = temp_hidden[:numex].view(numex, 1, H) + interm_g = temp_interm[:numex].view(numex, 1, I) + interm_u = temp_interm[numex:numex*2].view(numex, 1, I) + interm_a = temp_activa[:numex].view(numex, 1, I) + yh2 = temp_hidden + interm_gu = temp_interm + out_d = temp_output[:numex].view(numex, 1, H) + out_d2 = temp_output + + # Expert interval for split module (-1, -1) indicate no split mine, maxe = self.routing_first, self.routing_last if mine is None or maxe - mine == self.num_experts: mine, maxe = -1, -1 - self.experts_cfg = ExpertsCFG( + + cfg = ExpertsCFG( yh = yh, interm_g = interm_g, interm_u = interm_u, interm_a = interm_a, out_d = out_d, + out_d2 = out_d2, min_expert = mine, max_expert = maxe, ) + self.experts_cfg = cfg - cfg = self.experts_cfg if self.is_quantized: - sh_exp = None + # Embed bound classes for shared experts and shared gate + sh_exp_bc = None sh_exp_t = None - sh_gate = None + sh_gate_bc = None sh_gate_t = None self.bc_sh_exp = False if self.shared_experts and isinstance(self.shared_experts, GatedMLP) and self.shared_experts.bc is not None: self.bc_sh_exp = True - sh_exp = self.shared_experts.bc - sh_exp_t = torch.empty( - (1, 1, self.hidden_size), - dtype = self.out_dtype or torch.half, - device = self.device - ) + sh_exp_bc = self.shared_experts.bc + sh_exp_t = torch.empty((1, 1, H), dtype = self.out_dtype or torch.half, device = self.device) if self.shared_gate: assert self.shared_gate.quant_type == "fp16" - sh_gate = self.shared_gate.inner.bc + sh_gate_bc = self.shared_gate.inner.bc sh_gate_t = torch.empty((1, 1, 1), dtype = self.shared_gate.out_dtype, device = self.device) + g_trellis_ptr = torch.tensor([l.inner.trellis.data_ptr() for l in self.gates]) + u_trellis_ptr = torch.tensor([l.inner.trellis.data_ptr() for l in self.ups]) + g_suh_ptr = torch.tensor([l.inner.suh.data_ptr() for l in self.gates]) + 
u_suh_ptr = torch.tensor([l.inner.suh.data_ptr() for l in self.ups]) + g_svh_ptr = torch.tensor([l.inner.svh.data_ptr() for l in self.gates]) + u_svh_ptr = torch.tensor([l.inner.svh.data_ptr() for l in self.ups]) + gu_trellis_ptr = torch.stack((g_trellis_ptr, u_trellis_ptr), dim = 0).T.contiguous().to(self.device) + gu_suh_ptr = torch.stack((g_suh_ptr, u_suh_ptr), dim = 0).T.contiguous().to(self.device) + gu_svh_ptr = torch.stack((g_svh_ptr, u_svh_ptr), dim = 0).T.contiguous().to(self.device) + + dq_temp_up = g_tensor_cache.get(device, (H, I), torch.half, "dq_temp") + dq_temp_down = dq_temp_up.view(I, H) + self.bc = ext.BC_BlockSparseMLP( + yh2, cfg.yh, + interm_gu, cfg.interm_g, cfg.interm_u, cfg.interm_a, cfg.out_d, + cfg.out_d2, sh_exp_t, sh_gate_t, + dq_temp_up, + dq_temp_down, cfg.min_expert, cfg.max_expert, self.multi_gate.ptrs_trellis, @@ -459,9 +485,15 @@ class BlockSparseMLP(Module): self.multi_down.mul1, self.activation_fn == "silu", self.activation_fn == "gelu", - sh_exp, - sh_gate, - self.act_limit + sh_exp_bc, + sh_gate_bc, + self.act_limit, + [x.inner.bc for x in self.gates], + [x.inner.bc for x in self.ups], + [x.inner.bc for x in self.downs], + gu_trellis_ptr, + gu_suh_ptr, + gu_svh_ptr ) @@ -545,45 +577,101 @@ class BlockSparseMLP(Module): if self.intermediate_size == 0 or self.num_local_experts == 0: final_hidden_states = torch.zeros_like(x, dtype = self.out_dtype) - # Torch path + # Torch/C++ path elif bsz >= self.f_threshold or not self.is_quantized: final_hidden_states = torch.zeros_like(y, dtype = self.out_dtype) - if self.routing_device is None or self.num_local_experts == self.num_experts: - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes = self.num_local_experts) - else: - # TODO: profile, maybe optimize - selected_experts -= self.routing_first - invalid = (selected_experts < 0) | (selected_experts >= self.num_local_experts) - shifted = torch.where(invalid, torch.zeros_like(selected_experts), selected_experts + 1) - 
expert_mask = F.one_hot(shifted, num_classes = self.num_local_experts + 1)[..., 1:] - # routing_weights[invalid] = 0.0 + # if self.routing_device is None or self.num_local_experts == self.num_experts: + # expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes = self.num_local_experts) + # else: + # selected_experts -= self.routing_first + # invalid = (selected_experts < 0) | (selected_experts >= self.num_local_experts) + # shifted = torch.where(invalid, torch.zeros_like(selected_experts), selected_experts + 1) + # expert_mask = F.one_hot(shifted, num_classes = self.num_local_experts + 1)[..., 1:] if self.num_local_experts is None or self.num_local_experts > 0: num_ex = self.num_local_experts or self.num_experts - expert_count = expert_mask.view(-1, num_ex).sum(dim = 0).cpu() - expert_mask = expert_mask.permute(2, 1, 0) - def mlp(exp_i, xc): - g = self.gates[exp_i].forward(xc, params) - u = self.ups[exp_i].forward(xc, params) - a = u if self.interm_dtype == torch.half else torch.empty_like(u, dtype = torch.half) - self.activation_fn_call(g, u, a, self.act_limit) - return self.downs[exp_i].forward(a, params) + num_tokens, top_k = selected_experts.shape + E = self.num_local_experts + + # Flatten assignments + flat_expert_global = selected_experts.reshape(-1) # [num_tokens * top_k] + flat_weight = routing_weights.reshape(-1) # [num_tokens * top_k] + + # Token indices corresponding to each flattened assignment + flat_token = torch.arange(num_tokens, device = y.device) + flat_token = flat_token.repeat_interleave(top_k) # [num_tokens * top_k] + + if self.routing_device is None or self.num_local_experts == self.num_experts: + flat_expert_local = flat_expert_global + else: + flat_expert_local = flat_expert_global - self.routing_first + valid = (flat_expert_local >= 0) & (flat_expert_local < E) + flat_expert_local = torch.where(valid, flat_expert_local, torch.full_like(flat_expert_local, E)) + + # Group once by local expert id (including sentinel for 
expert-P mode) + order = flat_expert_local.argsort() + local_sorted = flat_expert_local[order] + token_sorted = flat_token[order] + weight_sorted = flat_weight[order] + + # Count how many assignments per expert + expert_count = torch.bincount(local_sorted, minlength = E + 1) + expert_ptr = torch.empty(E + 2, device = y.device, dtype = torch.long) + expert_ptr[0] = 0 + expert_ptr[1:] = expert_count.cumsum(0) + expert_ptr = expert_ptr.tolist() for expert_idx in range(num_ex): - if expert_count[expert_idx] == 0: + start = expert_ptr[expert_idx] + end = expert_ptr[expert_idx + 1] + count = end - start + if count == 0: continue - idx, top_x = torch.where(expert_mask[expert_idx]) - current_state = y[None, top_x].reshape(-1, self.hidden_size) - current_state = mlp(expert_idx, current_state) * routing_weights[top_x, idx, None] + + top_x = token_sorted[start:end] + w = weight_sorted[start:end].unsqueeze(1) + + current_state = y.index_select(0, top_x) + + if self.bc is not None: + if count <= TEMP_ROWS: + self.bc.run_single_expert(current_state, expert_idx) + current_state = self.experts_cfg.out_d2[:count] + else: + out_state = torch.empty( + (count, self.hidden_size), + dtype = self.out_dtype or torch.half, + device = self.device + ) + interm = torch.empty( + (2, count, self.intermediate_size), + dtype = self.interm_dtype, + device = self.device + ) + interm_a = interm[0] if self.interm_dtype == torch.half else torch.empty_like(interm[0], dtype = torch.half) + yh = torch.empty((2, count, self.hidden_size), dtype = torch.half, device = self.device) + self.bc.run_single_expert_dq(current_state, expert_idx, yh, interm, interm_a, out_state) + current_state = out_state + else: + def mlp(exp_i, xc): + g = self.gates[exp_i].forward(xc, params) + u = self.ups[exp_i].forward(xc, params) + a = u if self.interm_dtype == torch.half else torch.empty_like(u, dtype = torch.half) + self.activation_fn_call(g, u, a, self.act_limit) + return self.downs[exp_i].forward(a, params) + + 
current_state = mlp(expert_idx, current_state) + + current_state.mul_(w) final_hidden_states.index_add_(0, top_x, current_state) final_hidden_states = final_hidden_states.reshape(x.shape) final_hidden_states = to2(final_hidden_states, out_dtype, self.out_dtype) - # Fused path + # Fused path, few tokens elif bsz > 1: final_hidden_states = torch.empty_like(y, dtype = self.out_dtype) @@ -592,12 +680,6 @@ class BlockSparseMLP(Module): selected_experts = selected_experts.unsqueeze(1) routing_weights = routing_weights.unsqueeze(1) - # yh = torch.empty((bsz, self.num_experts_per_tok, 1, self.hidden_size), dtype = torch.half, device = self.device) - # interm_g = torch.empty((bsz, self.num_experts_per_tok, 1, self.intermediate_size), dtype = self.interm_dtype, device = self.device) - # interm_u = torch.empty_like(interm_g) - # interm_a = torch.empty_like(interm_u, dtype = torch.half) if self.interm_dtype != torch.half else interm_u - # out_d = torch.empty((bsz, self.num_experts_per_tok, 1, self.hidden_size), dtype = self.out_dtype or torch.half, device = self.device) - cfg = self.experts_cfg mine, maxe = self.routing_first, self.routing_last