mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-03-15 00:07:24 +00:00
BlockSparseMLP: Add single expert graph
This commit is contained in:
@@ -14,7 +14,11 @@ enum GraphedParams
|
||||
GP_end,
|
||||
|
||||
GP_gemm_A,
|
||||
GP_gemm_B_trellis,
|
||||
GP_gemm_C,
|
||||
GP_gemm_B_suh,
|
||||
GP_gemm_A_had,
|
||||
GP_gemm_B_svh,
|
||||
|
||||
GP_mgemm,
|
||||
GP_mgemm_A,
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "../quant/exl3_gemm.cuh"
|
||||
#include "../quant/hadamard.cuh"
|
||||
#include "../quant/reconstruct.cuh"
|
||||
#include "../quant/exl3_devctx.cuh"
|
||||
#include "../activation.cuh"
|
||||
#include "../add.cuh"
|
||||
|
||||
@@ -336,10 +337,11 @@ BC_BlockSparseMLP::BC_BlockSparseMLP
|
||||
use_mgemm = gate_K == up_K;
|
||||
}
|
||||
|
||||
void BC_BlockSparseMLP::run_single_expert
|
||||
void BC_BlockSparseMLP::run_single_expert_gr
|
||||
(
|
||||
const at::Tensor& y,
|
||||
const int expert_idx
|
||||
const int expert_idx,
|
||||
Graph* graph
|
||||
)
|
||||
{
|
||||
int bsz = y.size(0);
|
||||
@@ -347,45 +349,45 @@ void BC_BlockSparseMLP::run_single_expert
|
||||
at::Tensor ai = interm_a2.slice(0, 0, bsz);
|
||||
at::Tensor oi = out_d2.slice(0, 0, bsz);
|
||||
|
||||
if (use_mgemm)
|
||||
{
|
||||
at::Tensor yb = y.unsqueeze(0);
|
||||
at::Tensor gui = interm_gu.slice(0, 0, bsz * 2).view({2, bsz, interm_gu.size(1)});
|
||||
at::Tensor yh2i = yh2.slice(0, 0, 2).view({2, 1, yh2.size(1)});
|
||||
|
||||
exl3_mgemm
|
||||
(
|
||||
yb,
|
||||
gu_trellis_ptr[expert_idx],
|
||||
gui,
|
||||
gu_suh_ptr[expert_idx],
|
||||
yh2,
|
||||
gu_svh_ptr[expert_idx],
|
||||
c10::nullopt,
|
||||
c10::nullopt,
|
||||
gate_K,
|
||||
-1,
|
||||
gate_mcg,
|
||||
gate_mul1,
|
||||
-1,
|
||||
-1,
|
||||
0
|
||||
);
|
||||
|
||||
at::Tensor gi = gui[0];
|
||||
at::Tensor ui = gui[1];
|
||||
|
||||
if (act_silu)
|
||||
silu_mul(gi, ui, ai, act_limit);
|
||||
else if (act_gelu)
|
||||
gelu_mul(gi, ui, ai, act_limit);
|
||||
}
|
||||
else
|
||||
// if (use_mgemm)
|
||||
// {
|
||||
// at::Tensor yb = y.unsqueeze(0);
|
||||
// at::Tensor gui = interm_gu.slice(0, 0, bsz * 2).view({2, bsz, interm_gu.size(1)});
|
||||
// at::Tensor yh2i = yh2.slice(0, 0, 2).view({2, 1, yh2.size(1)});
|
||||
//
|
||||
// exl3_mgemm
|
||||
// (
|
||||
// yb,
|
||||
// gu_trellis_ptr[expert_idx],
|
||||
// gui,
|
||||
// gu_suh_ptr[expert_idx],
|
||||
// yh2,
|
||||
// gu_svh_ptr[expert_idx],
|
||||
// c10::nullopt,
|
||||
// c10::nullopt,
|
||||
// gate_K,
|
||||
// -1,
|
||||
// gate_mcg,
|
||||
// gate_mul1,
|
||||
// -1,
|
||||
// -1,
|
||||
// 0
|
||||
// );
|
||||
//
|
||||
// at::Tensor gi = gui[0];
|
||||
// at::Tensor ui = gui[1];
|
||||
//
|
||||
// if (act_silu)
|
||||
// silu_mul(gi, ui, ai, act_limit);
|
||||
// else if (act_gelu)
|
||||
// gelu_mul(gi, ui, ai, act_limit);
|
||||
// }
|
||||
// else
|
||||
{
|
||||
at::Tensor gi = interm_gu.slice(0, 0, bsz);
|
||||
at::Tensor ui = interm_gu.slice(0, bsz, bsz * 2);
|
||||
|
||||
exl3_gemm
|
||||
exl3_gemm_gr
|
||||
(
|
||||
y,
|
||||
gates[expert_idx]->trellis,
|
||||
@@ -396,10 +398,11 @@ void BC_BlockSparseMLP::run_single_expert
|
||||
-1,
|
||||
gate_mcg,
|
||||
gate_mul1,
|
||||
0
|
||||
0,
|
||||
graph
|
||||
);
|
||||
|
||||
exl3_gemm
|
||||
exl3_gemm_gr
|
||||
(
|
||||
y,
|
||||
ups[expert_idx]->trellis,
|
||||
@@ -410,16 +413,17 @@ void BC_BlockSparseMLP::run_single_expert
|
||||
-1,
|
||||
up_mcg,
|
||||
up_mul1,
|
||||
0
|
||||
0,
|
||||
graph
|
||||
);
|
||||
|
||||
if (act_silu)
|
||||
silu_mul(gi, ui, ai, act_limit);
|
||||
silu_mul_gr(gi, ui, ai, act_limit, graph);
|
||||
else if (act_gelu)
|
||||
gelu_mul(gi, ui, ai, act_limit);
|
||||
gelu_mul_gr(gi, ui, ai, act_limit, graph);
|
||||
}
|
||||
|
||||
exl3_gemm
|
||||
exl3_gemm_gr
|
||||
(
|
||||
ai,
|
||||
downs[expert_idx]->trellis,
|
||||
@@ -430,10 +434,64 @@ void BC_BlockSparseMLP::run_single_expert
|
||||
-1,
|
||||
down_mcg,
|
||||
down_mul1,
|
||||
0
|
||||
0,
|
||||
graph
|
||||
);
|
||||
}
|
||||
|
||||
void BC_BlockSparseMLP::run_single_expert
|
||||
(
|
||||
const at::Tensor& y,
|
||||
const int expert_idx
|
||||
)
|
||||
{
|
||||
int bsz = y.size(0);
|
||||
TORCH_CHECK(bsz <= TEMP_ROWS);
|
||||
int graphidx = bsz - 1;
|
||||
|
||||
c10::cuda::CUDAGuard device_guard(y.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
#define USE_GRAPH
|
||||
#ifndef USE_GRAPH
|
||||
|
||||
run_single_expert_gr(y, expert_idx, nullptr);
|
||||
|
||||
#else
|
||||
|
||||
if (!graph_single[graphidx].ready)
|
||||
{
|
||||
prepare_ctx(y.get_device());
|
||||
|
||||
graph_single[graphidx].capture_begin();
|
||||
run_single_expert_gr(y, expert_idx, &graph_single[graphidx]);
|
||||
graph_single[graphidx].capture_end();
|
||||
}
|
||||
|
||||
auto args = std::vector<PPTR>
|
||||
{
|
||||
PPTR(GP_gemm_A, (void*) y.data_ptr()),
|
||||
PPTR(GP_gemm_B_trellis, (void*) gates[expert_idx]->trellis.data_ptr()),
|
||||
PPTR(GP_gemm_B_suh, (void*) gates[expert_idx]->suh.data_ptr()),
|
||||
PPTR(GP_gemm_B_svh, (void*) gates[expert_idx]->svh.data_ptr()),
|
||||
PPTR(GP_end, nullptr),
|
||||
PPTR(GP_gemm_A, (void*) y.data_ptr()),
|
||||
PPTR(GP_gemm_B_trellis, (void*) ups[expert_idx]->trellis.data_ptr()),
|
||||
PPTR(GP_gemm_B_suh, (void*) ups[expert_idx]->suh.data_ptr()),
|
||||
PPTR(GP_gemm_B_svh, (void*) ups[expert_idx]->svh.data_ptr()),
|
||||
PPTR(GP_end, nullptr),
|
||||
PPTR(GP_gemm_B_trellis, (void*) downs[expert_idx]->trellis.data_ptr()),
|
||||
PPTR(GP_gemm_B_suh, (void*) downs[expert_idx]->suh.data_ptr()),
|
||||
PPTR(GP_gemm_B_svh, (void*) downs[expert_idx]->svh.data_ptr()),
|
||||
};
|
||||
|
||||
graph_single[graphidx].launch(args, stream);
|
||||
|
||||
#endif
|
||||
#undef USE_GRAPH
|
||||
}
|
||||
|
||||
|
||||
void BC_BlockSparseMLP::run_single_expert_dq
|
||||
(
|
||||
const at::Tensor& y,
|
||||
|
||||
@@ -10,6 +10,7 @@ namespace py = pybind11;
|
||||
#include "../graph.cuh"
|
||||
|
||||
#define MAX_EXPERTS 512
|
||||
#define TEMP_ROWS 32
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> blocksparse_mlp_routing(
|
||||
int bsz,
|
||||
@@ -75,6 +76,7 @@ struct BC_BlockSparseMLP
|
||||
bool use_mgemm;
|
||||
|
||||
Graph graph_bsz1;
|
||||
Graph graph_single[TEMP_ROWS];
|
||||
|
||||
BC_BlockSparseMLP
|
||||
(
|
||||
@@ -139,6 +141,13 @@ struct BC_BlockSparseMLP
|
||||
at::Tensor& routing_weights
|
||||
);
|
||||
|
||||
void run_single_expert_gr
|
||||
(
|
||||
const at::Tensor& y,
|
||||
const int expert_idx,
|
||||
Graph* graph
|
||||
);
|
||||
|
||||
void run_single_expert
|
||||
(
|
||||
const at::Tensor& y,
|
||||
|
||||
@@ -76,4 +76,11 @@ int g_get_cc(int device)
|
||||
int g_get_num_sms(int device)
|
||||
{
|
||||
return DevCtx::instance().get_num_sms(device);
|
||||
}
|
||||
}
|
||||
|
||||
void prepare_ctx(int device)
|
||||
{
|
||||
DevCtx::instance().get_num_sms(device);
|
||||
DevCtx::instance().get_cc(device);
|
||||
DevCtx::instance().get_locks(device);
|
||||
}
|
||||
|
||||
@@ -42,3 +42,5 @@ private:
|
||||
|
||||
int g_get_cc(int device);
|
||||
int g_get_num_sms(int device);
|
||||
|
||||
void prepare_ctx(int device);
|
||||
@@ -156,7 +156,11 @@ int exl3_gemm_gr
|
||||
);
|
||||
|
||||
if (graph) graph->record_param((void*) kernel, GP_gemm_A, 0);
|
||||
if (graph) graph->record_param((void*) kernel, GP_gemm_B_trellis, 1);
|
||||
if (graph) graph->record_param((void*) kernel, GP_gemm_C, 2);
|
||||
if (graph) graph->record_param((void*) kernel, GP_gemm_B_suh, 7);
|
||||
if (graph) graph->record_param((void*) kernel, GP_gemm_A_had, 8);
|
||||
if (graph) graph->record_param((void*) kernel, GP_gemm_B_svh, 9);
|
||||
if (graph) graph->record_param((void*) kernel, GP_end, 0);
|
||||
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
|
||||
Reference in New Issue
Block a user