BlockSparseMLP: Add single expert graph

This commit is contained in:
turboderp
2026-03-07 04:55:37 +01:00
parent 7ad51c0422
commit 0e8dd89874
6 changed files with 129 additions and 45 deletions

View File

@@ -14,7 +14,11 @@ enum GraphedParams
GP_end,
GP_gemm_A,
GP_gemm_B_trellis,
GP_gemm_C,
GP_gemm_B_suh,
GP_gemm_A_had,
GP_gemm_B_svh,
GP_mgemm,
GP_mgemm_A,

View File

@@ -8,6 +8,7 @@
#include "../quant/exl3_gemm.cuh"
#include "../quant/hadamard.cuh"
#include "../quant/reconstruct.cuh"
#include "../quant/exl3_devctx.cuh"
#include "../activation.cuh"
#include "../add.cuh"
@@ -336,10 +337,11 @@ BC_BlockSparseMLP::BC_BlockSparseMLP
use_mgemm = gate_K == up_K;
}
void BC_BlockSparseMLP::run_single_expert
void BC_BlockSparseMLP::run_single_expert_gr
(
const at::Tensor& y,
const int expert_idx
const int expert_idx,
Graph* graph
)
{
int bsz = y.size(0);
@@ -347,45 +349,45 @@ void BC_BlockSparseMLP::run_single_expert
at::Tensor ai = interm_a2.slice(0, 0, bsz);
at::Tensor oi = out_d2.slice(0, 0, bsz);
if (use_mgemm)
{
at::Tensor yb = y.unsqueeze(0);
at::Tensor gui = interm_gu.slice(0, 0, bsz * 2).view({2, bsz, interm_gu.size(1)});
at::Tensor yh2i = yh2.slice(0, 0, 2).view({2, 1, yh2.size(1)});
exl3_mgemm
(
yb,
gu_trellis_ptr[expert_idx],
gui,
gu_suh_ptr[expert_idx],
yh2,
gu_svh_ptr[expert_idx],
c10::nullopt,
c10::nullopt,
gate_K,
-1,
gate_mcg,
gate_mul1,
-1,
-1,
0
);
at::Tensor gi = gui[0];
at::Tensor ui = gui[1];
if (act_silu)
silu_mul(gi, ui, ai, act_limit);
else if (act_gelu)
gelu_mul(gi, ui, ai, act_limit);
}
else
// if (use_mgemm)
// {
// at::Tensor yb = y.unsqueeze(0);
// at::Tensor gui = interm_gu.slice(0, 0, bsz * 2).view({2, bsz, interm_gu.size(1)});
// at::Tensor yh2i = yh2.slice(0, 0, 2).view({2, 1, yh2.size(1)});
//
// exl3_mgemm
// (
// yb,
// gu_trellis_ptr[expert_idx],
// gui,
// gu_suh_ptr[expert_idx],
// yh2,
// gu_svh_ptr[expert_idx],
// c10::nullopt,
// c10::nullopt,
// gate_K,
// -1,
// gate_mcg,
// gate_mul1,
// -1,
// -1,
// 0
// );
//
// at::Tensor gi = gui[0];
// at::Tensor ui = gui[1];
//
// if (act_silu)
// silu_mul(gi, ui, ai, act_limit);
// else if (act_gelu)
// gelu_mul(gi, ui, ai, act_limit);
// }
// else
{
at::Tensor gi = interm_gu.slice(0, 0, bsz);
at::Tensor ui = interm_gu.slice(0, bsz, bsz * 2);
exl3_gemm
exl3_gemm_gr
(
y,
gates[expert_idx]->trellis,
@@ -396,10 +398,11 @@ void BC_BlockSparseMLP::run_single_expert
-1,
gate_mcg,
gate_mul1,
0
0,
graph
);
exl3_gemm
exl3_gemm_gr
(
y,
ups[expert_idx]->trellis,
@@ -410,16 +413,17 @@ void BC_BlockSparseMLP::run_single_expert
-1,
up_mcg,
up_mul1,
0
0,
graph
);
if (act_silu)
silu_mul(gi, ui, ai, act_limit);
silu_mul_gr(gi, ui, ai, act_limit, graph);
else if (act_gelu)
gelu_mul(gi, ui, ai, act_limit);
gelu_mul_gr(gi, ui, ai, act_limit, graph);
}
exl3_gemm
exl3_gemm_gr
(
ai,
downs[expert_idx]->trellis,
@@ -430,10 +434,64 @@ void BC_BlockSparseMLP::run_single_expert
-1,
down_mcg,
down_mul1,
0
0,
graph
);
}
// Run the gate/up/activation/down pipeline for one expert on input y,
// replaying a per-batch-size captured CUDA graph once it exists.
// y:          input activations; dim 0 is the batch size (must be <= TEMP_ROWS)
// expert_idx: index into the per-expert weight arrays (gates/ups/downs)
void BC_BlockSparseMLP::run_single_expert
(
const at::Tensor& y,
const int expert_idx
)
{
int bsz = y.size(0);
TORCH_CHECK(bsz <= TEMP_ROWS);
// One captured graph per batch size, slots 0..TEMP_ROWS-1
int graphidx = bsz - 1;
c10::cuda::CUDAGuard device_guard(y.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
// Local debug toggle: USE_GRAPH is defined (and #undef'd below), so the
// graph path always compiles; the #ifndef branch is a kept escape hatch
// that runs the same pipeline eagerly without graph capture.
#define USE_GRAPH
#ifndef USE_GRAPH
run_single_expert_gr(y, expert_idx, nullptr);
#else
// First call at this batch size: capture the whole pipeline into a graph.
// prepare_ctx() forces DevCtx lazy initialization to happen before
// capture begins, so no disallowed work occurs inside the capture.
if (!graph_single[graphidx].ready)
{
prepare_ctx(y.get_device());
graph_single[graphidx].capture_begin();
run_single_expert_gr(y, expert_idx, &graph_single[graphidx]);
graph_single[graphidx].capture_end();
}
// Patch the captured kernels' pointer parameters for this launch. The
// entries must stay in the exact order the params were recorded during
// capture: gate gemm, up gemm, then down gemm, with GP_end apparently
// acting as a separator between consecutive kernels — NOTE(review):
// confirm against Graph::launch / record_param semantics.
// The down gemm gets no GP_gemm_A entry: its input is the internal
// activation buffer, whose address is stable across replays.
auto args = std::vector<PPTR>
{
PPTR(GP_gemm_A, (void*) y.data_ptr()),
PPTR(GP_gemm_B_trellis, (void*) gates[expert_idx]->trellis.data_ptr()),
PPTR(GP_gemm_B_suh, (void*) gates[expert_idx]->suh.data_ptr()),
PPTR(GP_gemm_B_svh, (void*) gates[expert_idx]->svh.data_ptr()),
PPTR(GP_end, nullptr),
PPTR(GP_gemm_A, (void*) y.data_ptr()),
PPTR(GP_gemm_B_trellis, (void*) ups[expert_idx]->trellis.data_ptr()),
PPTR(GP_gemm_B_suh, (void*) ups[expert_idx]->suh.data_ptr()),
PPTR(GP_gemm_B_svh, (void*) ups[expert_idx]->svh.data_ptr()),
PPTR(GP_end, nullptr),
PPTR(GP_gemm_B_trellis, (void*) downs[expert_idx]->trellis.data_ptr()),
PPTR(GP_gemm_B_suh, (void*) downs[expert_idx]->suh.data_ptr()),
PPTR(GP_gemm_B_svh, (void*) downs[expert_idx]->svh.data_ptr()),
};
graph_single[graphidx].launch(args, stream);
#endif
#undef USE_GRAPH
}
void BC_BlockSparseMLP::run_single_expert_dq
(
const at::Tensor& y,

View File

@@ -10,6 +10,7 @@ namespace py = pybind11;
#include "../graph.cuh"
#define MAX_EXPERTS 512
#define TEMP_ROWS 32
std::tuple<at::Tensor, at::Tensor> blocksparse_mlp_routing(
int bsz,
@@ -75,6 +76,7 @@ struct BC_BlockSparseMLP
bool use_mgemm;
Graph graph_bsz1;
Graph graph_single[TEMP_ROWS];
BC_BlockSparseMLP
(
@@ -139,6 +141,13 @@ struct BC_BlockSparseMLP
at::Tensor& routing_weights
);
void run_single_expert_gr
(
const at::Tensor& y,
const int expert_idx,
Graph* graph
);
void run_single_expert
(
const at::Tensor& y,

View File

@@ -76,4 +76,11 @@ int g_get_cc(int device)
int g_get_num_sms(int device)
{
return DevCtx::instance().get_num_sms(device);
}
}
// Warm the per-device context: touch the DevCtx singleton's accessors so
// any lazy initialization they perform happens now — notably before CUDA
// graph capture, where such work would be disallowed. Return values are
// deliberately discarded; only the side effects matter.
// NOTE(review): assumes the get_* accessors cache their results on first
// use — confirm against DevCtx.
void prepare_ctx(int device)
{
    auto& ctx = DevCtx::instance();
    ctx.get_num_sms(device);
    ctx.get_cc(device);
    ctx.get_locks(device);
}

View File

@@ -42,3 +42,5 @@ private:
int g_get_cc(int device);
int g_get_num_sms(int device);
void prepare_ctx(int device);

View File

@@ -156,7 +156,11 @@ int exl3_gemm_gr
);
if (graph) graph->record_param((void*) kernel, GP_gemm_A, 0);
if (graph) graph->record_param((void*) kernel, GP_gemm_B_trellis, 1);
if (graph) graph->record_param((void*) kernel, GP_gemm_C, 2);
if (graph) graph->record_param((void*) kernel, GP_gemm_B_suh, 7);
if (graph) graph->record_param((void*) kernel, GP_gemm_A_had, 8);
if (graph) graph->record_param((void*) kernel, GP_gemm_B_svh, 9);
if (graph) graph->record_param((void*) kernel, GP_end, 0);
cuda_check(cudaPeekAtLastError());