diff --git a/exllamav3/exllamav3_ext/graph.cuh b/exllamav3/exllamav3_ext/graph.cuh
index 63c766a..59236eb 100644
--- a/exllamav3/exllamav3_ext/graph.cuh
+++ b/exllamav3/exllamav3_ext/graph.cuh
@@ -14,7 +14,11 @@ enum GraphedParams
     GP_end,
 
     GP_gemm_A,
+    GP_gemm_B_trellis,
     GP_gemm_C,
+    GP_gemm_B_suh,
+    GP_gemm_A_had,
+    GP_gemm_B_svh,
 
     GP_mgemm,
     GP_mgemm_A,
diff --git a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp
index 48d747c..e2c26f0 100644
--- a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp
+++ b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp
@@ -8,6 +8,7 @@
 #include "../quant/exl3_gemm.cuh"
 #include "../quant/hadamard.cuh"
 #include "../quant/reconstruct.cuh"
+#include "../quant/exl3_devctx.cuh"
 #include "../activation.cuh"
 #include "../add.cuh"
 
@@ -336,10 +337,15 @@ BC_BlockSparseMLP::BC_BlockSparseMLP
     use_mgemm = gate_K == up_K;
 }
 
-void BC_BlockSparseMLP::run_single_expert
+// Enqueue one expert's MLP for the bsz rows of y: gate and up GEMMs, SiLU/GELU-mul
+// of the two interm_gu slices into interm_a2, then the down GEMM on that result.
+// graph is forwarded to every *_gr call so that each launch can record its
+// parameters on it; pass nullptr to run without recording.
+void BC_BlockSparseMLP::run_single_expert_gr
 (
     const at::Tensor& y,
-    const int expert_idx
+    const int expert_idx,
+    Graph* graph
 )
 {
     int bsz = y.size(0);
@@ -347,45 +353,45 @@ void BC_BlockSparseMLP::run_single_expert
     at::Tensor ai = interm_a2.slice(0, 0, bsz);
     at::Tensor oi = out_d2.slice(0, 0, bsz);
 
-    if (use_mgemm)
-    {
-        at::Tensor yb = y.unsqueeze(0);
-        at::Tensor gui = interm_gu.slice(0, 0, bsz * 2).view({2, bsz, interm_gu.size(1)});
-        at::Tensor yh2i = yh2.slice(0, 0, 2).view({2, 1, yh2.size(1)});
-
-        exl3_mgemm
-        (
-            yb,
-            gu_trellis_ptr[expert_idx],
-            gui,
-            gu_suh_ptr[expert_idx],
-            yh2,
-            gu_svh_ptr[expert_idx],
-            c10::nullopt,
-            c10::nullopt,
-            gate_K,
-            -1,
-            gate_mcg,
-            gate_mul1,
-            -1,
-            -1,
-            0
-        );
-
-        at::Tensor gi = gui[0];
-        at::Tensor ui = gui[1];
-
-        if (act_silu)
-            silu_mul(gi, ui, ai, act_limit);
-        else if (act_gelu)
-            gelu_mul(gi, ui, ai, act_limit);
-    }
-    else
+//    if (use_mgemm)
+//    {
+//        at::Tensor yb = y.unsqueeze(0);
+//        at::Tensor gui = interm_gu.slice(0, 0, bsz * 2).view({2, bsz, interm_gu.size(1)});
+//        at::Tensor yh2i = yh2.slice(0, 0, 2).view({2, 1, yh2.size(1)});
+//
+//        exl3_mgemm
+//        (
+//            yb,
+//            gu_trellis_ptr[expert_idx],
+//            gui,
+//            gu_suh_ptr[expert_idx],
+//            yh2,
+//            gu_svh_ptr[expert_idx],
+//            c10::nullopt,
+//            c10::nullopt,
+//            gate_K,
+//            -1,
+//            gate_mcg,
+//            gate_mul1,
+//            -1,
+//            -1,
+//            0
+//        );
+//
+//        at::Tensor gi = gui[0];
+//        at::Tensor ui = gui[1];
+//
+//        if (act_silu)
+//            silu_mul(gi, ui, ai, act_limit);
+//        else if (act_gelu)
+//            gelu_mul(gi, ui, ai, act_limit);
+//    }
+//    else
     {
         at::Tensor gi = interm_gu.slice(0, 0, bsz);
         at::Tensor ui = interm_gu.slice(0, bsz, bsz * 2);
 
-        exl3_gemm
+        exl3_gemm_gr
         (
             y,
             gates[expert_idx]->trellis,
@@ -396,10 +402,11 @@ void BC_BlockSparseMLP::run_single_expert
             -1,
             gate_mcg,
             gate_mul1,
-            0
+            0,
+            graph
         );
 
-        exl3_gemm
+        exl3_gemm_gr
         (
             y,
             ups[expert_idx]->trellis,
@@ -410,16 +417,17 @@ void BC_BlockSparseMLP::run_single_expert
             -1,
             up_mcg,
             up_mul1,
-            0
+            0,
+            graph
         );
 
         if (act_silu)
-            silu_mul(gi, ui, ai, act_limit);
+            silu_mul_gr(gi, ui, ai, act_limit, graph);
         else if (act_gelu)
-            gelu_mul(gi, ui, ai, act_limit);
+            gelu_mul_gr(gi, ui, ai, act_limit, graph);
     }
 
-    exl3_gemm
+    exl3_gemm_gr
     (
         ai,
         downs[expert_idx]->trellis,
@@ -430,10 +438,81 @@ void BC_BlockSparseMLP::run_single_expert
         -1,
         down_mcg,
         down_mul1,
-        0
+        0,
+        graph
     );
 }
 
+// Run a single expert on y through a Graph cached per batch size.
+//
+// The first call for a given bsz captures run_single_expert_gr() into
+// graph_single[bsz - 1]. Every call then launches that graph with the input
+// pointer and the selected expert's trellis/suh/svh pointers supplied as
+// parameter overrides, so one capture is shared by all experts.
+//
+// y:          input rows, 1 <= y.size(0) <= TEMP_ROWS
+// expert_idx: index into gates / ups / downs
+void BC_BlockSparseMLP::run_single_expert
+(
+    const at::Tensor& y,
+    const int expert_idx
+)
+{
+    int bsz = y.size(0);
+    // graph_single is indexed by bsz - 1, so an empty batch must be rejected
+    // as well as an oversized one
+    TORCH_CHECK(bsz >= 1 && bsz <= TEMP_ROWS, "run_single_expert: bsz out of range: ", bsz);
+    int graphidx = bsz - 1;
+
+    c10::cuda::CUDAGuard device_guard(y.device());
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+
+    #define USE_GRAPH
+    #ifndef USE_GRAPH
+
+    run_single_expert_gr(y, expert_idx, nullptr);
+
+    #else
+
+    if (!graph_single[graphidx].ready)
+    {
+        // NOTE(review): presumably forces lazy per-device state to be created
+        // before capture starts - confirm against DevCtx
+        prepare_ctx(y.get_device());
+
+        graph_single[graphidx].capture_begin();
+        run_single_expert_gr(y, expert_idx, &graph_single[graphidx]);
+        graph_single[graphidx].capture_end();
+    }
+
+    // One group per exl3_gemm_gr call in capture order (gate, up, down). The
+    // down GEMM reads the interm_a2 slice and lists no GP_gemm_A override.
+    // NOTE(review): how GP_end separates groups is defined by Graph::launch,
+    // which is not visible here - confirm
+    auto args = std::vector
+    {
+        PPTR(GP_gemm_A, (void*) y.data_ptr()),
+        PPTR(GP_gemm_B_trellis, (void*) gates[expert_idx]->trellis.data_ptr()),
+        PPTR(GP_gemm_B_suh, (void*) gates[expert_idx]->suh.data_ptr()),
+        PPTR(GP_gemm_B_svh, (void*) gates[expert_idx]->svh.data_ptr()),
+        PPTR(GP_end, nullptr),
+        PPTR(GP_gemm_A, (void*) y.data_ptr()),
+        PPTR(GP_gemm_B_trellis, (void*) ups[expert_idx]->trellis.data_ptr()),
+        PPTR(GP_gemm_B_suh, (void*) ups[expert_idx]->suh.data_ptr()),
+        PPTR(GP_gemm_B_svh, (void*) ups[expert_idx]->svh.data_ptr()),
+        PPTR(GP_end, nullptr),
+        PPTR(GP_gemm_B_trellis, (void*) downs[expert_idx]->trellis.data_ptr()),
+        PPTR(GP_gemm_B_suh, (void*) downs[expert_idx]->suh.data_ptr()),
+        PPTR(GP_gemm_B_svh, (void*) downs[expert_idx]->svh.data_ptr()),
+    };
+
+    graph_single[graphidx].launch(args, stream);
+
+    #endif
+    #undef USE_GRAPH
+}
+
+
 void BC_BlockSparseMLP::run_single_expert_dq
 (
     const at::Tensor& y,
diff --git a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.h b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.h
index 2af0f21..f5c0ffd 100644
--- a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.h
+++ b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.h
@@ -10,6 +10,7 @@ namespace py = pybind11;
 #include "../graph.cuh"
 
 #define MAX_EXPERTS 512
+#define TEMP_ROWS 32
 
 std::tuple blocksparse_mlp_routing(
     int bsz,
@@ -75,6 +76,7 @@ struct BC_BlockSparseMLP
     bool use_mgemm;
 
     Graph graph_bsz1;
+    Graph graph_single[TEMP_ROWS];
 
     BC_BlockSparseMLP
     (
@@ -139,6 +141,13 @@ struct BC_BlockSparseMLP
         at::Tensor& routing_weights
     );
 
+    void run_single_expert_gr
+    (
+        const at::Tensor& y,
+        const int expert_idx,
+        Graph* graph
+    );
+
     void run_single_expert
     (
         const at::Tensor& y,
diff --git a/exllamav3/exllamav3_ext/quant/exl3_devctx.cu b/exllamav3/exllamav3_ext/quant/exl3_devctx.cu
index d00adb1..620b23c 100644
--- a/exllamav3/exllamav3_ext/quant/exl3_devctx.cu
+++ b/exllamav3/exllamav3_ext/quant/exl3_devctx.cu
@@ -76,4 +76,14 @@ int g_get_cc(int device)
 int g_get_num_sms(int device)
 {
     return DevCtx::instance().get_num_sms(device);
-}
\ No newline at end of file
+}
+
+// Query num_sms, cc and locks for the device once, discarding the results.
+// NOTE(review): presumably warms DevCtx's per-device state so nothing is
+// initialized lazily during graph capture - confirm against DevCtx
+void prepare_ctx(int device)
+{
+    DevCtx::instance().get_num_sms(device);
+    DevCtx::instance().get_cc(device);
+    DevCtx::instance().get_locks(device);
+}
diff --git a/exllamav3/exllamav3_ext/quant/exl3_devctx.cuh b/exllamav3/exllamav3_ext/quant/exl3_devctx.cuh
index 4cb17cd..ac8efae 100644
--- a/exllamav3/exllamav3_ext/quant/exl3_devctx.cuh
+++ b/exllamav3/exllamav3_ext/quant/exl3_devctx.cuh
@@ -42,3 +42,5 @@ private:
 
 int g_get_cc(int device);
 int g_get_num_sms(int device);
+
+void prepare_ctx(int device);
diff --git a/exllamav3/exllamav3_ext/quant/exl3_gemm.cu b/exllamav3/exllamav3_ext/quant/exl3_gemm.cu
index 6451ab9..daaca14 100644
--- a/exllamav3/exllamav3_ext/quant/exl3_gemm.cu
+++ b/exllamav3/exllamav3_ext/quant/exl3_gemm.cu
@@ -156,7 +156,11 @@ int exl3_gemm_gr
     );
 
     if (graph) graph->record_param((void*) kernel, GP_gemm_A, 0);
+    if (graph) graph->record_param((void*) kernel, GP_gemm_B_trellis, 1);
     if (graph) graph->record_param((void*) kernel, GP_gemm_C, 2);
+    if (graph) graph->record_param((void*) kernel, GP_gemm_B_suh, 7);
+    if (graph) graph->record_param((void*) kernel, GP_gemm_A_had, 8);
+    if (graph) graph->record_param((void*) kernel, GP_gemm_B_svh, 9);
     if (graph) graph->record_param((void*) kernel, GP_end, 0);
 
     cuda_check(cudaPeekAtLastError());