mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-03-15 00:07:24 +00:00
BlockSparseMLP: Improved batch routing
This commit is contained in:
@@ -6,6 +6,8 @@
|
||||
#include "../util.h"
|
||||
#include "../hgemm.cuh"
|
||||
#include "../quant/exl3_gemm.cuh"
|
||||
#include "../quant/hadamard.cuh"
|
||||
#include "../quant/reconstruct.cuh"
|
||||
#include "../activation.cuh"
|
||||
#include "../add.cuh"
|
||||
|
||||
@@ -69,7 +71,7 @@ void BC_BlockSparseMLP::run_bsz1_gr
|
||||
Graph* graph
|
||||
)
|
||||
{
|
||||
py::gil_scoped_release _;
|
||||
//py::gil_scoped_release _;
|
||||
const at::Tensor& yi = y.unsqueeze(0);
|
||||
|
||||
exl3_mgemm_gr
|
||||
@@ -213,4 +215,258 @@ void BC_BlockSparseMLP::run_bsz1
|
||||
graph_bsz1.launch(args, stream);
|
||||
|
||||
#endif
|
||||
#undef USE_GRAPH
|
||||
}
|
||||
|
||||
// Bound-class wrapper around a block-sparse (MoE) MLP layer.
//
// Holds preallocated scratch tensors (yh/yh2, interm_*, out_d*), device-side
// pointer tables for the per-expert gate/up/down trellis weights, and the
// per-expert linear layers. All by-value parameters are moved into members;
// copying them here (previously done for _gates/_ups/_downs, the gu_* pointer
// tables and the shared-expert handles) was wasteful and inconsistent with the
// rest of the initializer list.
BC_BlockSparseMLP::BC_BlockSparseMLP
(
    at::Tensor _yh2,
    at::Tensor _yh,
    at::Tensor _interm_gu,
    at::Tensor _interm_g,
    at::Tensor _interm_u,
    at::Tensor _interm_a,
    at::Tensor _out_d,
    at::Tensor _out_d2,
    c10::optional<at::Tensor> _out_d_sh,
    c10::optional<at::Tensor> _z,
    at::Tensor _dq_temp_up,
    at::Tensor _dq_temp_down,
    int _min_expert,
    int _max_expert,
    at::Tensor _gate_ptrs_trellis,
    at::Tensor _gate_ptrs_suh,
    at::Tensor _gate_ptrs_svh,
    int _gate_K,
    bool _gate_mcg,
    bool _gate_mul1,
    at::Tensor _up_ptrs_trellis,
    at::Tensor _up_ptrs_suh,
    at::Tensor _up_ptrs_svh,
    int _up_K,
    bool _up_mcg,
    bool _up_mul1,
    at::Tensor _down_ptrs_trellis,
    at::Tensor _down_ptrs_suh,
    at::Tensor _down_ptrs_svh,
    int _down_K,
    bool _down_mcg,
    bool _down_mul1,
    bool _act_silu,
    bool _act_gelu,
    std::shared_ptr<BC_GatedMLP> _shared_experts,
    std::shared_ptr<BC_LinearFP16> _shared_gate,
    float _act_limit,
    std::vector<std::shared_ptr<BC_LinearEXL3>> _gates,
    std::vector<std::shared_ptr<BC_LinearEXL3>> _ups,
    std::vector<std::shared_ptr<BC_LinearEXL3>> _downs,
    at::Tensor _gu_trellis_ptr,
    at::Tensor _gu_suh_ptr,
    at::Tensor _gu_svh_ptr
) :
    yh2 (std::move(_yh2)),
    yh (std::move(_yh)),
    interm_gu (std::move(_interm_gu)),
    interm_g (std::move(_interm_g)),
    interm_u (std::move(_interm_u)),
    interm_a (std::move(_interm_a)),
    out_d (std::move(_out_d)),
    out_d2 (std::move(_out_d2)),
    out_d_sh (std::move(_out_d_sh)),
    z (std::move(_z)),
    dq_temp_up (std::move(_dq_temp_up)),
    dq_temp_down (std::move(_dq_temp_down)),
    min_expert (_min_expert),
    max_expert (_max_expert),
    gate_ptrs_trellis (std::move(_gate_ptrs_trellis)),
    gate_ptrs_suh (std::move(_gate_ptrs_suh)),
    gate_ptrs_svh (std::move(_gate_ptrs_svh)),
    gate_K (_gate_K),
    gate_mcg (_gate_mcg),
    gate_mul1 (_gate_mul1),
    up_ptrs_trellis (std::move(_up_ptrs_trellis)),
    up_ptrs_suh (std::move(_up_ptrs_suh)),
    up_ptrs_svh (std::move(_up_ptrs_svh)),
    up_K (_up_K),
    up_mcg (_up_mcg),
    up_mul1 (_up_mul1),
    down_ptrs_trellis (std::move(_down_ptrs_trellis)),
    down_ptrs_suh (std::move(_down_ptrs_suh)),
    down_ptrs_svh (std::move(_down_ptrs_svh)),
    down_K (_down_K),
    down_mcg (_down_mcg),
    down_mul1 (_down_mul1),
    act_silu (_act_silu),
    act_gelu (_act_gelu),
    shared_experts (std::move(_shared_experts)),
    shared_gate (std::move(_shared_gate)),
    act_limit (_act_limit),
    gates (std::move(_gates)),
    ups (std::move(_ups)),
    downs (std::move(_downs)),
    gu_trellis_ptr (std::move(_gu_trellis_ptr)),
    gu_suh_ptr (std::move(_gu_suh_ptr)),
    gu_svh_ptr (std::move(_gu_svh_ptr))
{
    // Keep host-side mirrors of the pointer tables so individual experts can
    // be indexed on the CPU without touching device memory.
    gate_ptrs_trellis_cpu = gate_ptrs_trellis.cpu();
    gate_ptrs_suh_cpu = gate_ptrs_suh.cpu();
    gate_ptrs_svh_cpu = gate_ptrs_svh.cpu();
    up_ptrs_trellis_cpu = up_ptrs_trellis.cpu();
    up_ptrs_suh_cpu = up_ptrs_suh.cpu();
    up_ptrs_svh_cpu = up_ptrs_svh.cpu();
    down_ptrs_trellis_cpu = down_ptrs_trellis.cpu();
    down_ptrs_suh_cpu = down_ptrs_suh.cpu();
    down_ptrs_svh_cpu = down_ptrs_svh.cpu();

    // Scratch buffers are sized symmetrically: the same row count bounds both
    // experts-per-token and tokens-per-expert.
    max_experts_per_token = interm_g.size(0);
    max_tokens_per_expert = max_experts_per_token;

    // Pre-slice per-batch-size views so the hot path never re-slices.
    for (int i = 0; i < max_tokens_per_expert; ++i)
    {
        interm_g_single.push_back(interm_g.squeeze(1).slice(0, 0, i + 1));
        interm_u_single.push_back(interm_u.squeeze(1).slice(0, 0, i + 1));
        interm_a_single.push_back(interm_a.squeeze(1).slice(0, 0, i + 1));
        out_d_single.push_back(out_d.squeeze(1).slice(0, 0, i + 1));
    }

    TORCH_CHECK(max_expert <= MAX_EXPERTS, "BC_BlockSparseMLP: Too many experts");

    // Gate and up projections can share one multi-GEMM launch only when they
    // use the same quantization level.
    use_mgemm = gate_K == up_K;
}
|
||||
|
||||
void BC_BlockSparseMLP::run_single_expert
|
||||
(
|
||||
const at::Tensor& y,
|
||||
const int expert_idx
|
||||
)
|
||||
{
|
||||
int bsz = y.size(0);
|
||||
|
||||
at::Tensor ai = interm_a.slice(0, 0, bsz);
|
||||
at::Tensor oi = out_d2.slice(0, 0, bsz);
|
||||
|
||||
if (use_mgemm)
|
||||
{
|
||||
at::Tensor yb = y.unsqueeze(0);
|
||||
at::Tensor gui = interm_gu.slice(0, 0, bsz * 2).view({2, bsz, interm_gu.size(1)});
|
||||
at::Tensor yh2i = yh2.slice(0, 0, 2).view({2, 1, yh2.size(1)});
|
||||
|
||||
exl3_mgemm
|
||||
(
|
||||
yb,
|
||||
gu_trellis_ptr[expert_idx],
|
||||
gui,
|
||||
gu_suh_ptr[expert_idx],
|
||||
yh2,
|
||||
gu_svh_ptr[expert_idx],
|
||||
c10::nullopt,
|
||||
c10::nullopt,
|
||||
gate_K,
|
||||
-1,
|
||||
gate_mcg,
|
||||
gate_mul1,
|
||||
-1,
|
||||
-1,
|
||||
0
|
||||
);
|
||||
|
||||
at::Tensor gi = gui[0];
|
||||
at::Tensor ui = gui[1];
|
||||
|
||||
if (act_silu)
|
||||
silu_mul(gi, ui, ai, act_limit);
|
||||
else if (act_gelu)
|
||||
gelu_mul(gi, ui, ai, act_limit);
|
||||
}
|
||||
else
|
||||
{
|
||||
at::Tensor gi = interm_gu.slice(0, 0, bsz);
|
||||
at::Tensor ui = interm_gu.slice(0, bsz, bsz * 2);
|
||||
|
||||
exl3_gemm
|
||||
(
|
||||
y,
|
||||
gates[expert_idx]->trellis,
|
||||
gi,
|
||||
gates[expert_idx]->suh,
|
||||
yh,
|
||||
gates[expert_idx]->svh,
|
||||
-1,
|
||||
gate_mcg,
|
||||
gate_mul1,
|
||||
0
|
||||
);
|
||||
|
||||
exl3_gemm
|
||||
(
|
||||
y,
|
||||
ups[expert_idx]->trellis,
|
||||
ui,
|
||||
ups[expert_idx]->suh,
|
||||
yh,
|
||||
ups[expert_idx]->svh,
|
||||
-1,
|
||||
up_mcg,
|
||||
up_mul1,
|
||||
0
|
||||
);
|
||||
|
||||
if (act_silu)
|
||||
silu_mul(gi, ui, ai, act_limit);
|
||||
else if (act_gelu)
|
||||
gelu_mul(gi, ui, ai, act_limit);
|
||||
}
|
||||
|
||||
exl3_gemm
|
||||
(
|
||||
ai,
|
||||
downs[expert_idx]->trellis,
|
||||
oi,
|
||||
downs[expert_idx]->suh,
|
||||
ai,
|
||||
downs[expert_idx]->svh,
|
||||
-1,
|
||||
down_mcg,
|
||||
down_mul1,
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
void BC_BlockSparseMLP::run_single_expert_dq
|
||||
(
|
||||
const at::Tensor& y,
|
||||
const int expert_idx,
|
||||
at::Tensor& yh,
|
||||
at::Tensor& interm,
|
||||
at::Tensor& interm_a,
|
||||
at::Tensor& out
|
||||
)
|
||||
{
|
||||
int bsz = y.size(0);
|
||||
|
||||
at::Tensor yh1 = yh[0];
|
||||
at::Tensor yh2 = yh[0];
|
||||
at::Tensor interm1 = interm[0];
|
||||
at::Tensor interm2 = interm[1];
|
||||
|
||||
had_r_128_dual(y, yh1, gates[expert_idx]->suh, c10::nullopt,
|
||||
y, yh2, ups[expert_idx]->suh, c10::nullopt, 1.0);
|
||||
|
||||
reconstruct(dq_temp_up, gates[expert_idx]->trellis, gate_K, gate_mcg, gate_mul1);
|
||||
hgemm(yh1, dq_temp_up, interm1);
|
||||
reconstruct(dq_temp_up, ups[expert_idx]->trellis, up_K, up_mcg, up_mul1);
|
||||
hgemm(yh2, dq_temp_up, interm2);
|
||||
|
||||
had_r_128_dual(interm1, interm1, c10::nullopt, gates[expert_idx]->svh,
|
||||
interm2, interm2, c10::nullopt, ups[expert_idx]->svh, 1.0);
|
||||
|
||||
if (act_silu)
|
||||
silu_mul(interm1, interm2, interm_a, act_limit);
|
||||
else if (act_gelu)
|
||||
gelu_mul(interm1, interm2, interm_a, act_limit);
|
||||
|
||||
had_r_128(interm_a, interm_a, downs[expert_idx]->suh, c10::nullopt, 1.0);
|
||||
reconstruct(dq_temp_down, downs[expert_idx]->trellis, down_K, down_mcg, down_mul1);
|
||||
hgemm(interm_a, dq_temp_down, out);
|
||||
had_r_128(out, out, c10::nullopt, downs[expert_idx]->svh, 1.0);
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ namespace py = pybind11;
|
||||
#include "linear.h"
|
||||
#include "../graph.cuh"
|
||||
|
||||
#define MAX_EXPERTS 512
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> blocksparse_mlp_routing(
|
||||
int bsz,
|
||||
const py::object& cfg,
|
||||
@@ -18,30 +20,35 @@ std::tuple<at::Tensor, at::Tensor> blocksparse_mlp_routing(
|
||||
|
||||
struct BC_BlockSparseMLP
|
||||
{
|
||||
at::Tensor yh2;
|
||||
at::Tensor yh;
|
||||
at::Tensor interm_gu;
|
||||
at::Tensor interm_g;
|
||||
at::Tensor interm_u;
|
||||
at::Tensor interm_a;
|
||||
at::Tensor out_d;
|
||||
at::Tensor out_d2;
|
||||
c10::optional<at::Tensor> out_d_sh;
|
||||
c10::optional<at::Tensor> z;
|
||||
at::Tensor dq_temp_up;
|
||||
at::Tensor dq_temp_down;
|
||||
int min_expert;
|
||||
int max_expert;
|
||||
at::Tensor gate_ptrs_trellis;
|
||||
at::Tensor gate_ptrs_suh;
|
||||
at::Tensor gate_ptrs_svh;
|
||||
at::Tensor gate_ptrs_trellis; at::Tensor gate_ptrs_trellis_cpu;
|
||||
at::Tensor gate_ptrs_suh; at::Tensor gate_ptrs_suh_cpu;
|
||||
at::Tensor gate_ptrs_svh; at::Tensor gate_ptrs_svh_cpu;
|
||||
int gate_K;
|
||||
bool gate_mcg;
|
||||
bool gate_mul1;
|
||||
at::Tensor up_ptrs_trellis;
|
||||
at::Tensor up_ptrs_suh;
|
||||
at::Tensor up_ptrs_svh;
|
||||
at::Tensor up_ptrs_trellis; at::Tensor up_ptrs_trellis_cpu;
|
||||
at::Tensor up_ptrs_suh; at::Tensor up_ptrs_suh_cpu;
|
||||
at::Tensor up_ptrs_svh; at::Tensor up_ptrs_svh_cpu;
|
||||
int up_K;
|
||||
bool up_mcg;
|
||||
bool up_mul1;
|
||||
at::Tensor down_ptrs_trellis;
|
||||
at::Tensor down_ptrs_suh;
|
||||
at::Tensor down_ptrs_svh;
|
||||
at::Tensor down_ptrs_trellis; at::Tensor down_ptrs_trellis_cpu;
|
||||
at::Tensor down_ptrs_suh; at::Tensor down_ptrs_suh_cpu;
|
||||
at::Tensor down_ptrs_svh; at::Tensor down_ptrs_svh_cpu;
|
||||
int down_K;
|
||||
bool down_mcg;
|
||||
bool down_mul1;
|
||||
@@ -50,18 +57,38 @@ struct BC_BlockSparseMLP
|
||||
std::shared_ptr<BC_GatedMLP> shared_experts;
|
||||
std::shared_ptr<BC_LinearFP16> shared_gate;
|
||||
float act_limit;
|
||||
std::vector<std::shared_ptr<BC_LinearEXL3>> gates;
|
||||
std::vector<std::shared_ptr<BC_LinearEXL3>> ups;
|
||||
std::vector<std::shared_ptr<BC_LinearEXL3>> downs;
|
||||
at::Tensor gu_trellis_ptr;
|
||||
at::Tensor gu_suh_ptr;
|
||||
at::Tensor gu_svh_ptr;
|
||||
|
||||
int max_experts_per_token;
|
||||
int max_tokens_per_expert;
|
||||
std::vector<at::Tensor> interm_g_single;
|
||||
std::vector<at::Tensor> interm_u_single;
|
||||
std::vector<at::Tensor> interm_a_single;
|
||||
std::vector<at::Tensor> out_d_single;
|
||||
|
||||
bool use_mgemm;
|
||||
|
||||
Graph graph_bsz1;
|
||||
|
||||
BC_BlockSparseMLP
|
||||
(
|
||||
at::Tensor _yh2,
|
||||
at::Tensor _yh,
|
||||
at::Tensor _interm_gu,
|
||||
at::Tensor _interm_g,
|
||||
at::Tensor _interm_u,
|
||||
at::Tensor _interm_a,
|
||||
at::Tensor _out_d,
|
||||
at::Tensor _out_d2,
|
||||
c10::optional<at::Tensor> _out_d_sh,
|
||||
c10::optional<at::Tensor> _z,
|
||||
at::Tensor _dq_temp_up,
|
||||
at::Tensor _dq_temp_down,
|
||||
int _min_expert,
|
||||
int _max_expert,
|
||||
at::Tensor _gate_ptrs_trellis,
|
||||
@@ -86,41 +113,14 @@ struct BC_BlockSparseMLP
|
||||
bool _act_gelu,
|
||||
std::shared_ptr<BC_GatedMLP> _shared_experts,
|
||||
std::shared_ptr<BC_LinearFP16> _shared_gate,
|
||||
float _act_limit
|
||||
) :
|
||||
yh (std::move(_yh)),
|
||||
interm_g (std::move(_interm_g)),
|
||||
interm_u (std::move(_interm_u)),
|
||||
interm_a (std::move(_interm_a)),
|
||||
out_d (std::move(_out_d)),
|
||||
out_d_sh (std::move(_out_d_sh)),
|
||||
z (std::move(_z)),
|
||||
min_expert (_min_expert),
|
||||
max_expert (_max_expert),
|
||||
gate_ptrs_trellis (std::move(_gate_ptrs_trellis)),
|
||||
gate_ptrs_suh (std::move(_gate_ptrs_suh)),
|
||||
gate_ptrs_svh (std::move(_gate_ptrs_svh)),
|
||||
gate_K (_gate_K),
|
||||
gate_mcg (_gate_mcg),
|
||||
gate_mul1 (_gate_mul1),
|
||||
up_ptrs_trellis (std::move(_up_ptrs_trellis)),
|
||||
up_ptrs_suh (std::move(_up_ptrs_suh)),
|
||||
up_ptrs_svh (std::move(_up_ptrs_svh)),
|
||||
up_K (_up_K),
|
||||
up_mcg (_up_mcg),
|
||||
up_mul1 (_up_mul1),
|
||||
down_ptrs_trellis (std::move(_down_ptrs_trellis)),
|
||||
down_ptrs_suh (std::move(_down_ptrs_suh)),
|
||||
down_ptrs_svh (std::move(_down_ptrs_svh)),
|
||||
down_K (_down_K),
|
||||
down_mcg (_down_mcg),
|
||||
down_mul1 (_down_mul1),
|
||||
act_silu (_act_silu),
|
||||
act_gelu (_act_gelu),
|
||||
shared_experts (_shared_experts),
|
||||
shared_gate (_shared_gate),
|
||||
act_limit (_act_limit)
|
||||
{}
|
||||
float _act_limit,
|
||||
std::vector<std::shared_ptr<BC_LinearEXL3>> _gates,
|
||||
std::vector<std::shared_ptr<BC_LinearEXL3>> _ups,
|
||||
std::vector<std::shared_ptr<BC_LinearEXL3>> _downs,
|
||||
at::Tensor _gu_trellis_ptr,
|
||||
at::Tensor _gu_suh_ptr,
|
||||
at::Tensor _gu_svh_ptr
|
||||
);
|
||||
|
||||
void run_bsz1_gr
|
||||
(
|
||||
@@ -136,4 +136,20 @@ struct BC_BlockSparseMLP
|
||||
at::Tensor& selected_experts,
|
||||
at::Tensor& routing_weights
|
||||
);
|
||||
|
||||
void run_single_expert
|
||||
(
|
||||
const at::Tensor& y,
|
||||
const int expert_idx
|
||||
);
|
||||
|
||||
void run_single_expert_dq
|
||||
(
|
||||
const at::Tensor& y,
|
||||
const int expert_idx,
|
||||
at::Tensor& yh,
|
||||
at::Tensor& interm,
|
||||
at::Tensor& interm_a,
|
||||
at::Tensor& out
|
||||
);
|
||||
};
|
||||
@@ -6,8 +6,13 @@ py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSp
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
c10::optional<at::Tensor>,
|
||||
c10::optional<at::Tensor>,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
int,
|
||||
int,
|
||||
at::Tensor,
|
||||
@@ -32,15 +37,26 @@ py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSp
|
||||
bool,
|
||||
std::shared_ptr<BC_GatedMLP>,
|
||||
std::shared_ptr<BC_LinearFP16>,
|
||||
float
|
||||
float,
|
||||
std::vector<std::shared_ptr<BC_LinearEXL3>>,
|
||||
std::vector<std::shared_ptr<BC_LinearEXL3>>,
|
||||
std::vector<std::shared_ptr<BC_LinearEXL3>>,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor
|
||||
>(),
|
||||
py::arg("yh2"),
|
||||
py::arg("yh"),
|
||||
py::arg("interm_gu"),
|
||||
py::arg("interm_g"),
|
||||
py::arg("interm_u"),
|
||||
py::arg("interm_a"),
|
||||
py::arg("out_d"),
|
||||
py::arg("out_d2"),
|
||||
py::arg("out_d_sh"),
|
||||
py::arg("z"),
|
||||
py::arg("dq_temp_up"),
|
||||
py::arg("dq_temp_down"),
|
||||
py::arg("min_expert"),
|
||||
py::arg("max_expert"),
|
||||
py::arg("gate_ptrs_trellis"),
|
||||
@@ -65,6 +81,14 @@ py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSp
|
||||
py::arg("act_gelu"),
|
||||
py::arg("shared_experts"),
|
||||
py::arg("shared_gate"),
|
||||
py::arg("act_limit")
|
||||
py::arg("act_limit"),
|
||||
py::arg("gates"),
|
||||
py::arg("ups"),
|
||||
py::arg("downs"),
|
||||
py::arg("gu_trellis_ptr"),
|
||||
py::arg("gu_suh_ptr"),
|
||||
py::arg("gu_svh_ptr")
|
||||
)
|
||||
.def("run_bsz1", &BC_BlockSparseMLP::run_bsz1);
|
||||
.def("run_bsz1", &BC_BlockSparseMLP::run_bsz1)
|
||||
.def("run_single_expert", &BC_BlockSparseMLP::run_single_expert)
|
||||
.def("run_single_expert_dq", &BC_BlockSparseMLP::run_single_expert_dq);
|
||||
|
||||
@@ -36,6 +36,52 @@ void had_ff_r_128_kernel
|
||||
had_ff_r_128_inner(input_ptr, output_ptr, pre_scale, post_scale, r_scale);
|
||||
}
|
||||
|
||||
// Dual-stream half-precision 128-point Hadamard transform. Each block handles
// one 128-element segment (blockIdx.y) of one row (blockIdx.x) and applies the
// same transform to both input/output streams with their own scales.
__global__ __launch_bounds__(32)
void had_hf_r_128_dual_kernel
(
    const half* __restrict__ input1_ptr,
    half* __restrict__ output1_ptr,
    const half* __restrict__ pre1_scale,
    const half* __restrict__ post1_scale,
    const half* __restrict__ input2_ptr,
    half* __restrict__ output2_ptr,
    const half* __restrict__ pre2_scale,
    const half* __restrict__ post2_scale,
    const float r_scale
)
{
    // Flat offset of this block's 128-element segment, shared by both streams
    const unsigned int offset = gridDim.y * 128 * blockIdx.x + blockIdx.y * 128;

    had_hf_r_128_inner(input1_ptr + offset, output1_ptr + offset, pre1_scale, post1_scale, r_scale);
    had_hf_r_128_inner(input2_ptr + offset, output2_ptr + offset, pre2_scale, post2_scale, r_scale);
}
|
||||
|
||||
// Dual-stream float 128-point Hadamard transform. Mirrors the half-precision
// dual kernel: one block per (row, 128-segment) pair, transform applied to
// both streams with their respective pre/post scales.
__global__ __launch_bounds__(32)
void had_ff_r_128_dual_kernel
(
    const float* __restrict__ input1_ptr,
    float* __restrict__ output1_ptr,
    const half* __restrict__ pre1_scale,
    const half* __restrict__ post1_scale,
    const float* __restrict__ input2_ptr,
    float* __restrict__ output2_ptr,
    const half* __restrict__ pre2_scale,
    const half* __restrict__ post2_scale,
    const float r_scale
)
{
    // Flat offset of this block's 128-element segment, shared by both streams
    const unsigned int offset = gridDim.y * 128 * blockIdx.x + blockIdx.y * 128;

    had_ff_r_128_inner(input1_ptr + offset, output1_ptr + offset, pre1_scale, post1_scale, r_scale);
    had_ff_r_128_inner(input2_ptr + offset, output2_ptr + offset, pre2_scale, post2_scale, r_scale);
}
|
||||
|
||||
/*
|
||||
Compute y = (x.view(-1, 128) @ had_128).view(x.shape)
|
||||
Works inplace if y == x
|
||||
@@ -94,4 +140,73 @@ void had_r_128
|
||||
}
|
||||
|
||||
else TORCH_CHECK(false, "unsupported datatype");
|
||||
}
|
||||
}
|
||||
|
||||
void had_r_128_dual
|
||||
(
|
||||
const at::Tensor& input1,
|
||||
const at::Tensor& output1,
|
||||
const c10::optional<at::Tensor>& pre_scale1,
|
||||
const c10::optional<at::Tensor>& post_scale1,
|
||||
const at::Tensor& input2,
|
||||
const at::Tensor& output2,
|
||||
const c10::optional<at::Tensor>& pre_scale2,
|
||||
const c10::optional<at::Tensor>& post_scale2,
|
||||
const float scale
|
||||
)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard device_guard(input1.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
TORCH_CHECK_SHAPES_FULL(input1, output1);
|
||||
TORCH_CHECK_SHAPES_FULL(input1, input2);
|
||||
TORCH_CHECK_SHAPES_FULL(output1, output2);
|
||||
TORCH_CHECK_DIM(input1, 2);
|
||||
TORCH_CHECK_DIV(input1, 1, 128);
|
||||
int rows = input1.size(0);
|
||||
int cols = input1.size(1);
|
||||
|
||||
int blocks = cols / 128;
|
||||
float r_scale = scale * 0.088388347648f; // scale / sqrt(128)
|
||||
|
||||
dim3 blockDim(32);
|
||||
dim3 gridDim(rows, blocks);
|
||||
|
||||
if (input1.dtype() == at::kHalf)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(output1, kHalf);
|
||||
had_hf_r_128_dual_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const half*) input1.data_ptr(),
|
||||
(half*) output1.data_ptr(),
|
||||
(const half*) OPTPTR(pre_scale1),
|
||||
(const half*) OPTPTR(post_scale1),
|
||||
(const half*) input2.data_ptr(),
|
||||
(half*) output2.data_ptr(),
|
||||
(const half*) OPTPTR(pre_scale2),
|
||||
(const half*) OPTPTR(post_scale2),
|
||||
r_scale
|
||||
);
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
else if (input1.dtype() == at::kFloat)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(output1, kFloat);
|
||||
had_ff_r_128_dual_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const float*) input1.data_ptr(),
|
||||
(float*) output1.data_ptr(),
|
||||
(const half*) OPTPTR(pre_scale1),
|
||||
(const half*) OPTPTR(post_scale1),
|
||||
(const float*) input2.data_ptr(),
|
||||
(float*) output2.data_ptr(),
|
||||
(const half*) OPTPTR(pre_scale2),
|
||||
(const half*) OPTPTR(post_scale2),
|
||||
r_scale
|
||||
);
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
else TORCH_CHECK(false, "unsupported datatype");
|
||||
}
|
||||
|
||||
@@ -10,3 +10,16 @@ void had_r_128
|
||||
const c10::optional<at::Tensor>& post_scale,
|
||||
const float scale
|
||||
);
|
||||
|
||||
void had_r_128_dual
|
||||
(
|
||||
const at::Tensor& input1,
|
||||
const at::Tensor& output1,
|
||||
const c10::optional<at::Tensor>& pre_scale1,
|
||||
const c10::optional<at::Tensor>& post_scale1,
|
||||
const at::Tensor& input2,
|
||||
const at::Tensor& output2,
|
||||
const c10::optional<at::Tensor>& pre_scale2,
|
||||
const c10::optional<at::Tensor>& post_scale2,
|
||||
const float scale
|
||||
);
|
||||
|
||||
@@ -11,7 +11,9 @@ from dataclasses import dataclass
|
||||
from .mlp import MLP, GatedMLP
|
||||
from ..model.model_tp_alloc import TPAllocation
|
||||
from ..util import profile_opt
|
||||
from ..util.tensor import g_tensor_cache
|
||||
|
||||
TEMP_ROWS = 32
|
||||
|
||||
@dataclass
|
||||
class RoutingCFG:
|
||||
@@ -141,6 +143,7 @@ class ExpertsCFG:
|
||||
interm_u: torch.Tensor
|
||||
interm_a: torch.Tensor
|
||||
out_d: torch.Tensor
|
||||
out_d2: torch.Tensor
|
||||
min_expert: int
|
||||
max_expert: int
|
||||
|
||||
@@ -208,6 +211,9 @@ class BlockSparseMLP(Module):
|
||||
self.n_group = n_group
|
||||
self.topk_group = topk_group
|
||||
|
||||
assert num_experts_per_tok <= TEMP_ROWS, \
|
||||
f"Too many experts per token, max supported is {TEMP_ROWS}"
|
||||
|
||||
if routing_gate is None and key_routing_gate is None:
|
||||
self.routing_gate = None
|
||||
elif routing_gate is None:
|
||||
@@ -377,66 +383,86 @@ class BlockSparseMLP(Module):
|
||||
self.multi_up = MultiLinear(self.device, self.ups)
|
||||
self.multi_down = MultiLinear(self.device, self.downs)
|
||||
|
||||
yh = torch.empty(
|
||||
(self.num_experts_per_tok, 1, self.hidden_size),
|
||||
dtype = torch.half,
|
||||
device = self.device
|
||||
)
|
||||
interm_g = torch.empty(
|
||||
(self.num_experts_per_tok, 1, self.intermediate_size),
|
||||
dtype = self.interm_dtype,
|
||||
device = self.device
|
||||
)
|
||||
interm_u = torch.empty_like(interm_g)
|
||||
interm_a = torch.empty_like(interm_u, dtype = torch.half) if self.interm_dtype != torch.half else interm_u
|
||||
out_d = torch.empty(
|
||||
(self.num_experts_per_tok, 1, self.hidden_size),
|
||||
dtype = self.out_dtype or torch.half,
|
||||
device = self.device
|
||||
)
|
||||
# Temp buffers
|
||||
numex = self.num_experts_per_tok
|
||||
H = self.hidden_size
|
||||
I = self.intermediate_size
|
||||
device = self.device
|
||||
|
||||
temp_hidden = g_tensor_cache.get(device, (TEMP_ROWS * 2, H), torch.half, "temp_hidden")
|
||||
temp_interm = g_tensor_cache.get(device, (TEMP_ROWS * 2, I), self.interm_dtype, "temp_interm")
|
||||
temp_activa = g_tensor_cache.get(device, (TEMP_ROWS, I), torch.half, "temp_activa")
|
||||
temp_output = g_tensor_cache.get(device, (TEMP_ROWS, H), self.out_dtype or torch.half, "temp_output")
|
||||
|
||||
yh = temp_hidden[:numex].view(numex, 1, H)
|
||||
interm_g = temp_interm[:numex].view(numex, 1, I)
|
||||
interm_u = temp_interm[numex:numex*2].view(numex, 1, I)
|
||||
interm_a = temp_activa[:numex].view(numex, 1, I)
|
||||
yh2 = temp_hidden
|
||||
interm_gu = temp_interm
|
||||
out_d = temp_output[:numex].view(numex, 1, H)
|
||||
out_d2 = temp_output
|
||||
|
||||
# Expert interval for split module (-1, -1) indicate no split
|
||||
mine, maxe = self.routing_first, self.routing_last
|
||||
if mine is None or maxe - mine == self.num_experts:
|
||||
mine, maxe = -1, -1
|
||||
self.experts_cfg = ExpertsCFG(
|
||||
|
||||
cfg = ExpertsCFG(
|
||||
yh = yh,
|
||||
interm_g = interm_g,
|
||||
interm_u = interm_u,
|
||||
interm_a = interm_a,
|
||||
out_d = out_d,
|
||||
out_d2 = out_d2,
|
||||
min_expert = mine,
|
||||
max_expert = maxe,
|
||||
)
|
||||
self.experts_cfg = cfg
|
||||
|
||||
cfg = self.experts_cfg
|
||||
if self.is_quantized:
|
||||
|
||||
sh_exp = None
|
||||
# Embed bound classes for shared experts and shared gate
|
||||
sh_exp_bc = None
|
||||
sh_exp_t = None
|
||||
sh_gate = None
|
||||
sh_gate_bc = None
|
||||
sh_gate_t = None
|
||||
self.bc_sh_exp = False
|
||||
if self.shared_experts and isinstance(self.shared_experts, GatedMLP) and self.shared_experts.bc is not None:
|
||||
self.bc_sh_exp = True
|
||||
sh_exp = self.shared_experts.bc
|
||||
sh_exp_t = torch.empty(
|
||||
(1, 1, self.hidden_size),
|
||||
dtype = self.out_dtype or torch.half,
|
||||
device = self.device
|
||||
)
|
||||
sh_exp_bc = self.shared_experts.bc
|
||||
sh_exp_t = torch.empty((1, 1, H), dtype = self.out_dtype or torch.half, device = self.device)
|
||||
if self.shared_gate:
|
||||
assert self.shared_gate.quant_type == "fp16"
|
||||
sh_gate = self.shared_gate.inner.bc
|
||||
sh_gate_bc = self.shared_gate.inner.bc
|
||||
sh_gate_t = torch.empty((1, 1, 1), dtype = self.shared_gate.out_dtype, device = self.device)
|
||||
|
||||
g_trellis_ptr = torch.tensor([l.inner.trellis.data_ptr() for l in self.gates])
|
||||
u_trellis_ptr = torch.tensor([l.inner.trellis.data_ptr() for l in self.ups])
|
||||
g_suh_ptr = torch.tensor([l.inner.suh.data_ptr() for l in self.gates])
|
||||
u_suh_ptr = torch.tensor([l.inner.suh.data_ptr() for l in self.ups])
|
||||
g_svh_ptr = torch.tensor([l.inner.svh.data_ptr() for l in self.gates])
|
||||
u_svh_ptr = torch.tensor([l.inner.svh.data_ptr() for l in self.ups])
|
||||
gu_trellis_ptr = torch.stack((g_trellis_ptr, u_trellis_ptr), dim = 0).T.contiguous().to(self.device)
|
||||
gu_suh_ptr = torch.stack((g_suh_ptr, u_suh_ptr), dim = 0).T.contiguous().to(self.device)
|
||||
gu_svh_ptr = torch.stack((g_svh_ptr, u_svh_ptr), dim = 0).T.contiguous().to(self.device)
|
||||
|
||||
dq_temp_up = g_tensor_cache.get(device, (H, I), torch.half, "dq_temp")
|
||||
dq_temp_down = dq_temp_up.view(I, H)
|
||||
|
||||
self.bc = ext.BC_BlockSparseMLP(
|
||||
yh2,
|
||||
cfg.yh,
|
||||
interm_gu,
|
||||
cfg.interm_g,
|
||||
cfg.interm_u,
|
||||
cfg.interm_a,
|
||||
cfg.out_d,
|
||||
cfg.out_d2,
|
||||
sh_exp_t,
|
||||
sh_gate_t,
|
||||
dq_temp_up,
|
||||
dq_temp_down,
|
||||
cfg.min_expert,
|
||||
cfg.max_expert,
|
||||
self.multi_gate.ptrs_trellis,
|
||||
@@ -459,9 +485,15 @@ class BlockSparseMLP(Module):
|
||||
self.multi_down.mul1,
|
||||
self.activation_fn == "silu",
|
||||
self.activation_fn == "gelu",
|
||||
sh_exp,
|
||||
sh_gate,
|
||||
self.act_limit
|
||||
sh_exp_bc,
|
||||
sh_gate_bc,
|
||||
self.act_limit,
|
||||
[x.inner.bc for x in self.gates],
|
||||
[x.inner.bc for x in self.ups],
|
||||
[x.inner.bc for x in self.downs],
|
||||
gu_trellis_ptr,
|
||||
gu_suh_ptr,
|
||||
gu_svh_ptr
|
||||
)
|
||||
|
||||
|
||||
@@ -545,45 +577,101 @@ class BlockSparseMLP(Module):
|
||||
if self.intermediate_size == 0 or self.num_local_experts == 0:
|
||||
final_hidden_states = torch.zeros_like(x, dtype = self.out_dtype)
|
||||
|
||||
# Torch path
|
||||
# Torch/C++ path
|
||||
elif bsz >= self.f_threshold or not self.is_quantized:
|
||||
final_hidden_states = torch.zeros_like(y, dtype = self.out_dtype)
|
||||
|
||||
if self.routing_device is None or self.num_local_experts == self.num_experts:
|
||||
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes = self.num_local_experts)
|
||||
else:
|
||||
# TODO: profile, maybe optimize
|
||||
selected_experts -= self.routing_first
|
||||
invalid = (selected_experts < 0) | (selected_experts >= self.num_local_experts)
|
||||
shifted = torch.where(invalid, torch.zeros_like(selected_experts), selected_experts + 1)
|
||||
expert_mask = F.one_hot(shifted, num_classes = self.num_local_experts + 1)[..., 1:]
|
||||
# routing_weights[invalid] = 0.0
|
||||
# if self.routing_device is None or self.num_local_experts == self.num_experts:
|
||||
# expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes = self.num_local_experts)
|
||||
# else:
|
||||
# selected_experts -= self.routing_first
|
||||
# invalid = (selected_experts < 0) | (selected_experts >= self.num_local_experts)
|
||||
# shifted = torch.where(invalid, torch.zeros_like(selected_experts), selected_experts + 1)
|
||||
# expert_mask = F.one_hot(shifted, num_classes = self.num_local_experts + 1)[..., 1:]
|
||||
|
||||
if self.num_local_experts is None or self.num_local_experts > 0:
|
||||
|
||||
num_ex = self.num_local_experts or self.num_experts
|
||||
expert_count = expert_mask.view(-1, num_ex).sum(dim = 0).cpu()
|
||||
expert_mask = expert_mask.permute(2, 1, 0)
|
||||
|
||||
def mlp(exp_i, xc):
|
||||
g = self.gates[exp_i].forward(xc, params)
|
||||
u = self.ups[exp_i].forward(xc, params)
|
||||
a = u if self.interm_dtype == torch.half else torch.empty_like(u, dtype = torch.half)
|
||||
self.activation_fn_call(g, u, a, self.act_limit)
|
||||
return self.downs[exp_i].forward(a, params)
|
||||
num_tokens, top_k = selected_experts.shape
|
||||
E = self.num_local_experts
|
||||
|
||||
# Flatten assignments
|
||||
flat_expert_global = selected_experts.reshape(-1) # [num_tokens * top_k]
|
||||
flat_weight = routing_weights.reshape(-1) # [num_tokens * top_k]
|
||||
|
||||
# Token indices corresponding to each flattened assignment
|
||||
flat_token = torch.arange(num_tokens, device = y.device)
|
||||
flat_token = flat_token.repeat_interleave(top_k) # [num_tokens * top_k]
|
||||
|
||||
if self.routing_device is None or self.num_local_experts == self.num_experts:
|
||||
flat_expert_local = flat_expert_global
|
||||
else:
|
||||
flat_expert_local = flat_expert_global - self.routing_first
|
||||
valid = (flat_expert_local >= 0) & (flat_expert_local < E)
|
||||
flat_expert_local = torch.where(valid, flat_expert_local, torch.full_like(flat_expert_local, E))
|
||||
|
||||
# Group once by local expert id (including sentinel for expert-P mode)
|
||||
order = flat_expert_local.argsort()
|
||||
local_sorted = flat_expert_local[order]
|
||||
token_sorted = flat_token[order]
|
||||
weight_sorted = flat_weight[order]
|
||||
|
||||
# Count how many assignments per expert
|
||||
expert_count = torch.bincount(local_sorted, minlength = E + 1)
|
||||
expert_ptr = torch.empty(E + 2, device = y.device, dtype = torch.long)
|
||||
expert_ptr[0] = 0
|
||||
expert_ptr[1:] = expert_count.cumsum(0)
|
||||
expert_ptr = expert_ptr.tolist()
|
||||
|
||||
for expert_idx in range(num_ex):
|
||||
if expert_count[expert_idx] == 0:
|
||||
start = expert_ptr[expert_idx]
|
||||
end = expert_ptr[expert_idx + 1]
|
||||
count = end - start
|
||||
if count == 0:
|
||||
continue
|
||||
idx, top_x = torch.where(expert_mask[expert_idx])
|
||||
current_state = y[None, top_x].reshape(-1, self.hidden_size)
|
||||
current_state = mlp(expert_idx, current_state) * routing_weights[top_x, idx, None]
|
||||
|
||||
top_x = token_sorted[start:end]
|
||||
w = weight_sorted[start:end].unsqueeze(1)
|
||||
|
||||
current_state = y.index_select(0, top_x)
|
||||
|
||||
if self.bc is not None:
|
||||
if count <= TEMP_ROWS:
|
||||
self.bc.run_single_expert(current_state, expert_idx)
|
||||
current_state = self.experts_cfg.out_d2[:count]
|
||||
else:
|
||||
out_state = torch.empty(
|
||||
(count, self.hidden_size),
|
||||
dtype = self.out_dtype or torch.half,
|
||||
device = self.device
|
||||
)
|
||||
interm = torch.empty(
|
||||
(2, count, self.intermediate_size),
|
||||
dtype = self.interm_dtype,
|
||||
device = self.device
|
||||
)
|
||||
interm_a = interm[0] if self.interm_dtype == torch.half else torch.empty_like(interm[0], dtype = torch.half)
|
||||
yh = torch.empty((2, count, self.hidden_size), dtype = torch.half, device = self.device)
|
||||
self.bc.run_single_expert_dq(current_state, expert_idx, yh, interm, interm_a, out_state)
|
||||
current_state = out_state
|
||||
else:
|
||||
def mlp(exp_i, xc):
|
||||
g = self.gates[exp_i].forward(xc, params)
|
||||
u = self.ups[exp_i].forward(xc, params)
|
||||
a = u if self.interm_dtype == torch.half else torch.empty_like(u, dtype = torch.half)
|
||||
self.activation_fn_call(g, u, a, self.act_limit)
|
||||
return self.downs[exp_i].forward(a, params)
|
||||
|
||||
current_state = mlp(expert_idx, current_state)
|
||||
|
||||
current_state.mul_(w)
|
||||
final_hidden_states.index_add_(0, top_x, current_state)
|
||||
|
||||
final_hidden_states = final_hidden_states.reshape(x.shape)
|
||||
final_hidden_states = to2(final_hidden_states, out_dtype, self.out_dtype)
|
||||
|
||||
# Fused path
|
||||
# Fused path, few tokens
|
||||
elif bsz > 1:
|
||||
|
||||
final_hidden_states = torch.empty_like(y, dtype = self.out_dtype)
|
||||
@@ -592,12 +680,6 @@ class BlockSparseMLP(Module):
|
||||
selected_experts = selected_experts.unsqueeze(1)
|
||||
routing_weights = routing_weights.unsqueeze(1)
|
||||
|
||||
# yh = torch.empty((bsz, self.num_experts_per_tok, 1, self.hidden_size), dtype = torch.half, device = self.device)
|
||||
# interm_g = torch.empty((bsz, self.num_experts_per_tok, 1, self.intermediate_size), dtype = self.interm_dtype, device = self.device)
|
||||
# interm_u = torch.empty_like(interm_g)
|
||||
# interm_a = torch.empty_like(interm_u, dtype = torch.half) if self.interm_dtype != torch.half else interm_u
|
||||
# out_d = torch.empty((bsz, self.num_experts_per_tok, 1, self.hidden_size), dtype = self.out_dtype or torch.half, device = self.device)
|
||||
|
||||
cfg = self.experts_cfg
|
||||
|
||||
mine, maxe = self.routing_first, self.routing_last
|
||||
|
||||
Reference in New Issue
Block a user