mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 06:19:10 +00:00
Rework GEMM kernel tuning
This commit is contained in:
@@ -23,6 +23,7 @@
|
||||
#include "quant/exl3_gemm.cuh"
|
||||
#include "quant/exl3_kernel_map.cuh"
|
||||
#include "quant/util.cuh"
|
||||
#include "quant/exl3_devctx.cuh"
|
||||
|
||||
#include "generator/strings.h"
|
||||
#include "generator/sampling_basic.cuh"
|
||||
@@ -87,6 +88,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
m.def("exl3_gemm", &exl3_gemm, "exl3_gemm");
|
||||
m.def("exl3_gemm_num_kernel_shapes", &exl3_gemm_num_kernel_shapes, "exl3_gemm_num_kernel_shapes");
|
||||
m.def("exl3_gemm_shape_compat", &exl3_gemm_shape_compat, "exl3_gemm_shape_compat");
|
||||
m.def("g_get_cc", &g_get_cc, "g_get_cc");
|
||||
m.def("g_get_num_sms", &g_get_num_sms, "g_get_num_sms");
|
||||
m.def("exl3_mgemm", &exl3_mgemm, "exl3_mgemm");
|
||||
m.def("hgemm", &hgemm, "hgemm");
|
||||
m.def("rope", &rope, "rope");
|
||||
|
||||
@@ -85,7 +85,8 @@ void BC_BlockSparseMLP::run_bsz1
|
||||
gate_mcg_mult,
|
||||
gate_mul1_mult,
|
||||
min_expert,
|
||||
max_expert
|
||||
max_expert,
|
||||
0
|
||||
);
|
||||
|
||||
exl3_mgemm(
|
||||
@@ -102,7 +103,8 @@ void BC_BlockSparseMLP::run_bsz1
|
||||
up_mcg_mult,
|
||||
up_mul1_mult,
|
||||
min_expert,
|
||||
max_expert
|
||||
max_expert,
|
||||
0
|
||||
);
|
||||
|
||||
if (act_silu)
|
||||
@@ -124,7 +126,8 @@ void BC_BlockSparseMLP::run_bsz1
|
||||
down_mcg_mult,
|
||||
down_mul1_mult,
|
||||
min_expert,
|
||||
max_expert
|
||||
max_expert,
|
||||
0
|
||||
);
|
||||
|
||||
if (shared_experts)
|
||||
|
||||
@@ -31,12 +31,12 @@ void BC_LinearEXL3::run(const at::Tensor& x, at::Tensor& y)
|
||||
{
|
||||
if (x.numel() == x.size(-1))
|
||||
{
|
||||
exl3_gemm(x, trellis, y, suh, xh, svh, -1, mcg_mult, mul1_mult);
|
||||
exl3_gemm(x, trellis, y, suh, xh, svh, -1, mcg_mult, mul1_mult, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
at::Tensor xh_ = at::empty_like(x);
|
||||
exl3_gemm(x, trellis, y, suh, xh_, svh, -1, mcg_mult, mul1_mult);
|
||||
exl3_gemm(x, trellis, y, suh, xh_, svh, -1, mcg_mult, mul1_mult, 0);
|
||||
}
|
||||
|
||||
if (bias) y.add_(bias.value());
|
||||
|
||||
@@ -31,7 +31,8 @@ void BC_GatedMLP::run_bsz1
|
||||
gu_mcg_mult,
|
||||
gu_mul1_mult,
|
||||
-1,
|
||||
-1
|
||||
-1,
|
||||
0
|
||||
);
|
||||
|
||||
at::Tensor g = gu.select(0, 0).unsqueeze(0);
|
||||
|
||||
@@ -55,4 +55,14 @@ int* DevCtx::get_locks(int device)
|
||||
cudaMemset(locks[device], 0, MAX_TILES_C * sizeof(int));
|
||||
}
|
||||
return (int*) locks[device];
|
||||
}
|
||||
|
||||
int g_get_cc(int device)
|
||||
{
|
||||
return DevCtx::instance().get_cc(device);
|
||||
}
|
||||
|
||||
int g_get_num_sms(int device)
|
||||
{
|
||||
return DevCtx::instance().get_num_sms(device);
|
||||
}
|
||||
@@ -6,12 +6,13 @@
|
||||
// Max allowable output size, in tiles. Used to allocate global lock buffer per device for sync across threadblocks
|
||||
#define MAX_TILES_C (1024 * 1024)
|
||||
|
||||
// Treat hopper and blackwell as same arch for now
|
||||
#define MAX_DEVICES 32
|
||||
#define CC_OLD 1
|
||||
#define CC_AMPERE 2
|
||||
#define CC_ADA 3
|
||||
#define CC_HOPPER 4
|
||||
#define CC_BLACKWELL 5
|
||||
#define CC_BLACKWELL 4
|
||||
|
||||
// Singleton to manage context for each device. Stores device attributes and a large-enough lock buffer per device
|
||||
class DevCtx
|
||||
@@ -32,4 +33,7 @@ private:
|
||||
DevCtx() = default;
|
||||
DevCtx(const DevCtx&) = delete;
|
||||
DevCtx& operator=(const DevCtx&) = delete;
|
||||
};
|
||||
};
|
||||
|
||||
int g_get_cc(int device);
|
||||
int g_get_num_sms(int device);
|
||||
|
||||
@@ -12,11 +12,17 @@ namespace cg = cooperative_groups;
|
||||
#include "exl3_devctx.cuh"
|
||||
#include <set>
|
||||
|
||||
#define NEW_TUNE_GEMM
|
||||
#define NEW_TUNE_MGEMM
|
||||
|
||||
int exl3_gemm_tilesize_k_g[] = {EXL3_GEMM_TILESIZE_K};
|
||||
int exl3_gemm_tilesize_n_g[] = {EXL3_GEMM_TILESIZE_N};
|
||||
|
||||
/*
|
||||
EXL3 matmul, A @ B -> C
|
||||
|
||||
- A: row-major A tensor, shape (m, k), dtype float16, contiguous
|
||||
- B: EXL3-quantized B tensor, shape (k//16, n//16, 16*bits), dtype uint16
|
||||
- B: EXL3-quantized B tensor, shape (k//16, n//16, 16*K), dtype uint16
|
||||
- C: empty row-major C tensor, shape (m, n), dtype float16 or float32, contiguous. Does not need to be zero-initialized
|
||||
- suh: optional, packed input scales/flips, shape (k//16), dtype float16
|
||||
- A_had: required if suh given, may be reference to A, temporary storage for input transform, size and dtype as A
|
||||
@@ -39,7 +45,8 @@ int exl3_gemm
|
||||
const c10::optional<at::Tensor>& svh,
|
||||
int force_shape_idx,
|
||||
uint32_t mcg_mult,
|
||||
uint32_t mul1_mult
|
||||
uint32_t mul1_mult,
|
||||
int force_num_sms
|
||||
)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard device_guard(A.device());
|
||||
@@ -48,7 +55,7 @@ int exl3_gemm
|
||||
TORCH_CHECK_DIM(B, 3);
|
||||
TORCH_CHECK_SHAPES(A, -1, B, 0, 16);
|
||||
TORCH_CHECK_SHAPES(C, -1, B, 1, 16);
|
||||
// TORCH_CHECK_SHAPES(A, 0, C, 0, 1);
|
||||
// TORCH_CHECK_SHAPES(A, 0, C, 0, 1);
|
||||
TORCH_CHECK_DTYPE(A, kHalf);
|
||||
TORCH_CHECK_DTYPE(B, kShort);
|
||||
bool c_fp32 = C.dtype() == at::kFloat;
|
||||
@@ -59,26 +66,26 @@ int exl3_gemm
|
||||
half* A_had_ptr = nullptr;
|
||||
if (suh_ptr)
|
||||
{
|
||||
// TORCH_CHECK_SHAPES(suh.value(), 0, A, 1, 1);
|
||||
// TORCH_CHECK_SHAPES(suh.value(), 0, A, 1, 1);
|
||||
A_had_ptr = (half*) OPTPTR(A_had);
|
||||
// TORCH_CHECK(A_had_ptr, "Must supply A_had with suh");
|
||||
// TORCH_CHECK_SHAPES_FULL(A_had.value(), A);
|
||||
// TORCH_CHECK(A_had_ptr, "Must supply A_had with suh");
|
||||
// TORCH_CHECK_SHAPES_FULL(A_had.value(), A);
|
||||
}
|
||||
|
||||
// Get SV, optionally
|
||||
const half* svh_ptr = (const half*) OPTPTR(svh);
|
||||
// if (svh_ptr)
|
||||
// TORCH_CHECK_SHAPES(svh.value(), 0, B, 1, 16);
|
||||
// if (svh_ptr)
|
||||
// TORCH_CHECK_SHAPES(svh.value(), 0, B, 1, 16);
|
||||
|
||||
// Device properties
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int num_sms = DevCtx::instance().get_num_sms(device);
|
||||
int num_sms = force_num_sms ? force_num_sms : DevCtx::instance().get_num_sms(device);
|
||||
int cc = DevCtx::instance().get_cc(device);
|
||||
int* locks = DevCtx::instance().get_locks(device);
|
||||
|
||||
// Dispatch
|
||||
int bits = B.size(2) / 16;
|
||||
int K = B.size(2) / 16;
|
||||
const half* A_ptr = (const half*) A.data_ptr();
|
||||
const uint16_t* B_ptr = (const uint16_t*) B.data_ptr();
|
||||
void* C_ptr = (void*) C.data_ptr();
|
||||
@@ -96,21 +103,33 @@ int exl3_gemm
|
||||
if (mcg_mult) { cb = 1; mult = mcg_mult; }
|
||||
if (mul1_mult) { cb = 2; mult = mul1_mult; }
|
||||
|
||||
int selected_shape;
|
||||
int block_dim;
|
||||
fp_exl3_gemm_kernel kernel = select_exl3_gemm_kernel
|
||||
(
|
||||
cc, size_m, size_k, size_n, bits, c_fp32,
|
||||
force_shape_idx, &block_dim, &selected_shape,
|
||||
&num_sms, cb
|
||||
);
|
||||
if (!kernel) return 0;
|
||||
int shape_idx;
|
||||
fp_exl3_gemm_kernel kernel;
|
||||
|
||||
#ifndef NEW_TUNE_GEMM
|
||||
kernel = select_exl3_gemm_kernel
|
||||
(
|
||||
cc, size_m, size_k, size_n, K, c_fp32,
|
||||
force_shape_idx, &block_dim, &shape_idx,
|
||||
&num_sms, cb
|
||||
);
|
||||
if (!kernel) return 0;
|
||||
#else
|
||||
TResult* tr = select_exl3_gemm_mgemm_kernel_new(cc, size_m, size_k, size_n, K, c_fp32, force_shape_idx, force_num_sms, cb);
|
||||
if (!tr) return 0;
|
||||
num_sms = MIN(num_sms, tr->num_sms);
|
||||
kernel = tr->kernel;
|
||||
block_dim = tr->block_dim;
|
||||
shape_idx = tr->shape_idx;
|
||||
#endif
|
||||
|
||||
// Launch
|
||||
if (kernel_attr_set[device].find((void*)kernel) == kernel_attr_set[device].end())
|
||||
if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end())
|
||||
{
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
|
||||
kernel_attr_set[device].insert((void*)kernel);
|
||||
kernel_attr_set[device].insert((void*) kernel);
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
}
|
||||
void* kernelArgs[] =
|
||||
{
|
||||
@@ -128,7 +147,7 @@ int exl3_gemm
|
||||
};
|
||||
cudaLaunchCooperativeKernel
|
||||
(
|
||||
(void*)kernel,
|
||||
(void*) kernel,
|
||||
num_sms,
|
||||
block_dim,
|
||||
kernelArgs,
|
||||
@@ -136,14 +155,16 @@ int exl3_gemm
|
||||
stream
|
||||
);
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
return selected_shape;
|
||||
|
||||
// return selected_shape;
|
||||
return shape_idx;
|
||||
}
|
||||
|
||||
/*
|
||||
EXL3 multi matmul, A @ B -> C
|
||||
|
||||
- A: row-major A tensor, shape (m, k), dtype float16, contiguous
|
||||
- B: EXL3-quantized B tensor, shape (k//16, n//16, 16*bits), dtype uint16
|
||||
- B: EXL3-quantized B tensor, shape (k//16, n//16, 16*K), dtype uint16
|
||||
- C: empty row-major C tensor, shape (m, n), dtype float16 or float23, contiguous. Does not need to be zero-initialized
|
||||
- suh: optional, packed input scales/flips, shape (k//16), dtype float16
|
||||
- A_had: required if suh given, may be reference to A, temporary storage for input transform, size and dtype as A
|
||||
@@ -169,7 +190,8 @@ int exl3_mgemm
|
||||
uint32_t mcg_mult,
|
||||
uint32_t mul1_mult,
|
||||
int min_index,
|
||||
int max_index
|
||||
int max_index,
|
||||
int force_num_sms
|
||||
)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard device_guard(A.device());
|
||||
@@ -194,6 +216,7 @@ int exl3_mgemm
|
||||
int bsz = A.size(1);
|
||||
int bszm_in = A.size(0);
|
||||
int bszm_out = C.size(0);
|
||||
int bszm = MAX(bszm_in, bszm_out);
|
||||
|
||||
const long* indices_ptr = (const long*) OPTPTR(indices);
|
||||
const half* weights_ptr = (const half*) OPTPTR(weights);
|
||||
@@ -219,8 +242,8 @@ int exl3_mgemm
|
||||
// Device properties
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int num_sms = DevCtx::instance().get_num_sms(device);
|
||||
int total_sms = num_sms;
|
||||
int total_sms = DevCtx::instance().get_num_sms(device);
|
||||
int num_sms = force_num_sms ? force_num_sms : total_sms;
|
||||
int cc = DevCtx::instance().get_cc(device);
|
||||
int* locks = DevCtx::instance().get_locks(device);
|
||||
|
||||
@@ -239,25 +262,44 @@ int exl3_mgemm
|
||||
if (mcg_mult) { cb = 1; mult = mcg_mult; }
|
||||
if (mul1_mult) { cb = 2; mult = mul1_mult; }
|
||||
|
||||
int selected_shape;
|
||||
int shape_idx;
|
||||
int block_dim;
|
||||
fp_exl3_mgemm_kernel kernel = select_exl3_mgemm_kernel
|
||||
(
|
||||
cc, size_m, size_k, size_n, K, c_fp32,
|
||||
force_shape_idx, &block_dim, &selected_shape,
|
||||
&num_sms, cb, bszm_in, bszm_out
|
||||
);
|
||||
if (!kernel) return 0;
|
||||
fp_exl3_mgemm_kernel kernel;
|
||||
int concurrency;
|
||||
|
||||
#ifndef NEW_TUNE_MGEMM
|
||||
kernel = select_exl3_mgemm_kernel
|
||||
(
|
||||
cc, size_m, size_k, size_n, K, c_fp32,
|
||||
force_shape_idx, &block_dim, &shape_idx,
|
||||
&num_sms, cb, bszm_in, bszm_out
|
||||
);
|
||||
if (!kernel) return 0;
|
||||
concurrency = MIN(total_sms / num_sms, bszm_out);
|
||||
#else
|
||||
kernel = select_exl3_mgemm_kernel
|
||||
(
|
||||
cc, size_m, size_k, size_n, K, c_fp32,
|
||||
force_shape_idx, &block_dim, &shape_idx,
|
||||
&num_sms, cb, bszm_in, bszm_out
|
||||
);
|
||||
int tilesize_k = exl3_gemm_tilesize_k_g[shape_idx];
|
||||
int tilesize_n = exl3_gemm_tilesize_n_g[shape_idx];
|
||||
int tiles = MAX(size_k / tilesize_k * size_n / tilesize_n, 1);
|
||||
num_sms = tiles;
|
||||
if (num_sms * bszm > total_sms) num_sms = MAX(total_sms / bszm, 1);
|
||||
if (num_sms <= total_sms && tiles / num_sms > 48) num_sms = MIN(total_sms, num_sms * 2);
|
||||
concurrency = MIN(total_sms / num_sms, bszm);
|
||||
#endif
|
||||
|
||||
// Launch bigger grid if possible
|
||||
int concurrency = MIN(total_sms / num_sms, bszm_out);
|
||||
dim3 block_grid(num_sms, 1, concurrency);
|
||||
|
||||
// Launch
|
||||
if (kernel_attr_set[device].find((void*)kernel) == kernel_attr_set[device].end())
|
||||
if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end())
|
||||
{
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
|
||||
kernel_attr_set[device].insert((void*)kernel);
|
||||
kernel_attr_set[device].insert((void*) kernel);
|
||||
}
|
||||
void* kernelArgs[] =
|
||||
{
|
||||
@@ -279,10 +321,9 @@ int exl3_mgemm
|
||||
(void*)& min_index,
|
||||
(void*)& max_index
|
||||
};
|
||||
|
||||
cudaLaunchCooperativeKernel
|
||||
(
|
||||
(void*)kernel,
|
||||
(void*) kernel,
|
||||
block_grid,
|
||||
block_dim,
|
||||
kernelArgs,
|
||||
@@ -290,5 +331,5 @@ int exl3_mgemm
|
||||
stream
|
||||
);
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
return selected_shape;
|
||||
return shape_idx;
|
||||
}
|
||||
@@ -12,7 +12,8 @@ int exl3_gemm
|
||||
const c10::optional<at::Tensor>& svh,
|
||||
int force_shape_idx,
|
||||
uint32_t mcg_mult,
|
||||
uint32_t mul1_mult
|
||||
uint32_t mul1_mult,
|
||||
int force_num_sms
|
||||
);
|
||||
|
||||
int exl3_mgemm
|
||||
@@ -30,5 +31,6 @@ int exl3_mgemm
|
||||
uint32_t mcg_mult,
|
||||
uint32_t mul1_mult,
|
||||
int min_index,
|
||||
int max_index
|
||||
int max_index,
|
||||
int force_num_sms
|
||||
);
|
||||
|
||||
@@ -8,6 +8,7 @@ namespace cg = cooperative_groups;
|
||||
#include "../ptx.cuh"
|
||||
#include <tuple>
|
||||
#include <mutex>
|
||||
#include <map>
|
||||
#include "exl3_kernel_map.cuh"
|
||||
#include "exl3_devctx.cuh"
|
||||
#include "comp_units/exl3_comp_unit_1.cuh"
|
||||
@@ -19,7 +20,10 @@ namespace cg = cooperative_groups;
|
||||
#include "comp_units/exl3_comp_unit_7.cuh"
|
||||
#include "comp_units/exl3_comp_unit_8.cuh"
|
||||
|
||||
int select_gemm_shape(int cc, int size_m, int size_k, int size_n, int bits, bool multi, int bszm_in, int bszm_out)
|
||||
#include "exl3_kernel_map_samples.cuh"
|
||||
std::map<uint64_t, TResult> _tuning_cache = {};
|
||||
|
||||
int select_gemm_shape(int cc, int size_m, int size_k, int size_n, int K, bool multi, int bszm_in, int bszm_out)
|
||||
{
|
||||
bool mod_256 = (size_n % 256 == 0);
|
||||
bool mod_512 = (size_n % 512 == 0);
|
||||
@@ -31,18 +35,18 @@ int select_gemm_shape(int cc, int size_m, int size_k, int size_n, int bits, bool
|
||||
{
|
||||
case CC_OLD:
|
||||
case CC_AMPERE:
|
||||
if (mod_256 && bits <= 4)
|
||||
if (mod_256 && K <= 4)
|
||||
{
|
||||
if (size_n <= 2048 || size_k <= 2048) return 2;
|
||||
return 3;
|
||||
}
|
||||
if (mod_256 && size_n < 4096) return size_k > 8192 ? 3 : 2;
|
||||
if (mod_512 && (size_n * size_k) > (4096 * 4096) && bits <= 6) return 4;
|
||||
if (mod_512 && (size_n * size_k) > (4096 * 4096) && K <= 6) return 4;
|
||||
if (mod_256) return 3;
|
||||
return 2;
|
||||
|
||||
case CC_ADA:
|
||||
if (mod_256 && bits <= 3)
|
||||
if (mod_256 && K <= 3)
|
||||
{
|
||||
if (size_k <= 2048 && !multi) return 2;
|
||||
if (size_n < 4096 && size_k <= 12288) return 2;
|
||||
@@ -53,19 +57,19 @@ int select_gemm_shape(int cc, int size_m, int size_k, int size_n, int bits, bool
|
||||
if (mod_256) return 3;
|
||||
return 2;
|
||||
|
||||
case CC_HOPPER:
|
||||
// case CC_HOPPER:
|
||||
case CC_BLACKWELL:
|
||||
if ((bits == 4 || bits == 2) && !multi)
|
||||
if ((K == 4 || K == 2) && !multi)
|
||||
{
|
||||
if (size_k <= 2048) return 1;
|
||||
}
|
||||
if (bits >= 7)
|
||||
if (K >= 7)
|
||||
{
|
||||
if (mod_256 && size_n <= 8192) return size_k > 32768 ? 3 : 2;
|
||||
if (mod_512 && size_n > 32768) return 4;
|
||||
return 2;
|
||||
}
|
||||
if (mod_256 && size_n <= 4096) return size_k > 8192 && bits >= 3 ? 3 : 2;
|
||||
if (mod_256 && size_n <= 4096) return size_k > 8192 && K >= 3 ? 3 : 2;
|
||||
if (mod_512 && size_n > 16384) return 4;
|
||||
if (mod_256) return 3;
|
||||
return 2;
|
||||
@@ -82,7 +86,7 @@ int exl3_gemm_tilesize_k[] = {EXL3_GEMM_TILESIZE_K};
|
||||
int exl3_gemm_tilesize_n[] = {EXL3_GEMM_TILESIZE_N};
|
||||
int exl3_gemm_blockdim[] = {EXL3_GEMM_BLOCKDIM};
|
||||
|
||||
bool exl3_gemm_shape_compat(int shape_idx, int size_m, int size_k, int size_n, int bits)
|
||||
bool exl3_gemm_shape_compat(int shape_idx, int size_m, int size_k, int size_n, int K)
|
||||
{
|
||||
int tilesize_k = exl3_gemm_tilesize_k[shape_idx];
|
||||
int tilesize_n = exl3_gemm_tilesize_n[shape_idx];
|
||||
@@ -95,7 +99,7 @@ fp_exl3_gemm_kernel select_exl3_gemm_kernel
|
||||
int size_m,
|
||||
int size_k,
|
||||
int size_n,
|
||||
int bits,
|
||||
int K,
|
||||
bool c_fp32,
|
||||
int force_shape_idx,
|
||||
int* out_block_dim,
|
||||
@@ -104,7 +108,7 @@ fp_exl3_gemm_kernel select_exl3_gemm_kernel
|
||||
int cb
|
||||
)
|
||||
{
|
||||
int shape_idx = force_shape_idx <= 0 ? select_gemm_shape(cc, size_m, size_k, size_n, bits, false, 1, 1) : force_shape_idx;
|
||||
int shape_idx = force_shape_idx <= 0 ? select_gemm_shape(cc, size_m, size_k, size_n, K, false, 1, 1) : force_shape_idx;
|
||||
|
||||
TORCH_CHECK(shape_idx > 0, "exl3_gemm: no compatible kernel");
|
||||
if (out_shape_idx) *out_shape_idx = shape_idx;
|
||||
@@ -123,7 +127,7 @@ fp_exl3_gemm_kernel select_exl3_gemm_kernel
|
||||
|
||||
if (c_fp32)
|
||||
{
|
||||
switch (bits)
|
||||
switch (K)
|
||||
{
|
||||
case 1: return tfp_exl3_gemm_kernel_fp32_b1[kernel_idx];
|
||||
case 2: return tfp_exl3_gemm_kernel_fp32_b2[kernel_idx];
|
||||
@@ -138,7 +142,7 @@ fp_exl3_gemm_kernel select_exl3_gemm_kernel
|
||||
}
|
||||
else
|
||||
{
|
||||
switch (bits)
|
||||
switch (K)
|
||||
{
|
||||
case 1: return tfp_exl3_gemm_kernel_fp16_b1[kernel_idx];
|
||||
case 2: return tfp_exl3_gemm_kernel_fp16_b2[kernel_idx];
|
||||
@@ -159,7 +163,7 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
|
||||
int size_m,
|
||||
int size_k,
|
||||
int size_n,
|
||||
int bits,
|
||||
int K,
|
||||
bool c_fp32,
|
||||
int force_shape_idx,
|
||||
int* out_block_dim,
|
||||
@@ -170,7 +174,7 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
|
||||
int bszm_out
|
||||
)
|
||||
{
|
||||
int shape_idx = force_shape_idx <= 0 ? select_gemm_shape(cc, size_m, size_k, size_n, bits, true, bszm_in, bszm_out) : force_shape_idx;
|
||||
int shape_idx = force_shape_idx <= 0 ? select_gemm_shape(cc, size_m, size_k, size_n, K, true, bszm_in, bszm_out) : force_shape_idx;
|
||||
TORCH_CHECK(shape_idx > 0, "exl3_mgemm: no compatible kernel");
|
||||
if (out_shape_idx) *out_shape_idx = shape_idx;
|
||||
if (out_block_dim) *out_block_dim = exl3_gemm_blockdim[shape_idx];
|
||||
@@ -188,7 +192,7 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
|
||||
|
||||
if (c_fp32)
|
||||
{
|
||||
switch (bits)
|
||||
switch (K)
|
||||
{
|
||||
case 1: return tfp_exl3_mgemm_kernel_fp32_b1[kernel_idx];
|
||||
case 2: return tfp_exl3_mgemm_kernel_fp32_b2[kernel_idx];
|
||||
@@ -203,7 +207,7 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
|
||||
}
|
||||
else
|
||||
{
|
||||
switch (bits)
|
||||
switch (K)
|
||||
{
|
||||
case 1: return tfp_exl3_mgemm_kernel_fp16_b1[kernel_idx];
|
||||
case 2: return tfp_exl3_mgemm_kernel_fp16_b2[kernel_idx];
|
||||
@@ -216,4 +220,160 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
|
||||
default: TORCH_CHECK(false, "No kernel for GEMM shape");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fp_exl3_gemm_kernel get_gemm_kernel_ptr(int K, int shape_idx, bool c_fp32, int cb)
|
||||
{
|
||||
int kernel_idx = shape_idx + (EXL3_GEMM_NUM_SHAPES + 1) * cb;
|
||||
|
||||
if (c_fp32)
|
||||
{
|
||||
switch (K)
|
||||
{
|
||||
case 1: return tfp_exl3_gemm_kernel_fp32_b1[kernel_idx];
|
||||
case 2: return tfp_exl3_gemm_kernel_fp32_b2[kernel_idx];
|
||||
case 3: return tfp_exl3_gemm_kernel_fp32_b3[kernel_idx];
|
||||
case 4: return tfp_exl3_gemm_kernel_fp32_b4[kernel_idx];
|
||||
case 5: return tfp_exl3_gemm_kernel_fp32_b5[kernel_idx];
|
||||
case 6: return tfp_exl3_gemm_kernel_fp32_b6[kernel_idx];
|
||||
case 7: return tfp_exl3_gemm_kernel_fp32_b7[kernel_idx];
|
||||
case 8: return tfp_exl3_gemm_kernel_fp32_b8[kernel_idx];
|
||||
default: TORCH_CHECK(false, "No kernel for GEMM shape");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
switch (K)
|
||||
{
|
||||
case 1: return tfp_exl3_gemm_kernel_fp16_b1[kernel_idx];
|
||||
case 2: return tfp_exl3_gemm_kernel_fp16_b2[kernel_idx];
|
||||
case 3: return tfp_exl3_gemm_kernel_fp16_b3[kernel_idx];
|
||||
case 4: return tfp_exl3_gemm_kernel_fp16_b4[kernel_idx];
|
||||
case 5: return tfp_exl3_gemm_kernel_fp16_b5[kernel_idx];
|
||||
case 6: return tfp_exl3_gemm_kernel_fp16_b6[kernel_idx];
|
||||
case 7: return tfp_exl3_gemm_kernel_fp16_b7[kernel_idx];
|
||||
case 8: return tfp_exl3_gemm_kernel_fp16_b8[kernel_idx];
|
||||
default: TORCH_CHECK(false, "No kernel for GEMM shape");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fp_exl3_mgemm_kernel get_mgemm_kernel_ptr(int K, int shape_idx, bool c_fp32, int cb)
|
||||
{
|
||||
int kernel_idx = shape_idx + (EXL3_GEMM_NUM_SHAPES + 1) * cb;
|
||||
|
||||
if (c_fp32)
|
||||
{
|
||||
switch (K)
|
||||
{
|
||||
case 1: return tfp_exl3_mgemm_kernel_fp32_b1[kernel_idx];
|
||||
case 2: return tfp_exl3_mgemm_kernel_fp32_b2[kernel_idx];
|
||||
case 3: return tfp_exl3_mgemm_kernel_fp32_b3[kernel_idx];
|
||||
case 4: return tfp_exl3_mgemm_kernel_fp32_b4[kernel_idx];
|
||||
case 5: return tfp_exl3_mgemm_kernel_fp32_b5[kernel_idx];
|
||||
case 6: return tfp_exl3_mgemm_kernel_fp32_b6[kernel_idx];
|
||||
case 7: return tfp_exl3_mgemm_kernel_fp32_b7[kernel_idx];
|
||||
case 8: return tfp_exl3_mgemm_kernel_fp32_b8[kernel_idx];
|
||||
default: TORCH_CHECK(false, "No kernel for GEMM shape");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
switch (K)
|
||||
{
|
||||
case 1: return tfp_exl3_mgemm_kernel_fp16_b1[kernel_idx];
|
||||
case 2: return tfp_exl3_mgemm_kernel_fp16_b2[kernel_idx];
|
||||
case 3: return tfp_exl3_mgemm_kernel_fp16_b3[kernel_idx];
|
||||
case 4: return tfp_exl3_mgemm_kernel_fp16_b4[kernel_idx];
|
||||
case 5: return tfp_exl3_mgemm_kernel_fp16_b5[kernel_idx];
|
||||
case 6: return tfp_exl3_mgemm_kernel_fp16_b6[kernel_idx];
|
||||
case 7: return tfp_exl3_mgemm_kernel_fp16_b7[kernel_idx];
|
||||
case 8: return tfp_exl3_mgemm_kernel_fp16_b8[kernel_idx];
|
||||
default: TORCH_CHECK(false, "No kernel for GEMM shape");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TResult f_tr;
|
||||
|
||||
TResult* select_exl3_gemm_mgemm_kernel_new
|
||||
(
|
||||
int cc,
|
||||
int size_m,
|
||||
int size_k,
|
||||
int size_n,
|
||||
int K,
|
||||
bool c_fp32,
|
||||
int force_shape_idx,
|
||||
int force_num_sms,
|
||||
int cb
|
||||
)
|
||||
{
|
||||
// Force parameters for tuning/benchmarking
|
||||
if (force_shape_idx > 0)
|
||||
{
|
||||
TORCH_CHECK(force_num_sms, "Must supply force_shape_idx and force_num_sms together");
|
||||
f_tr.kernel = get_gemm_kernel_ptr(K, force_shape_idx, c_fp32, cb);
|
||||
f_tr.mkernel = get_mgemm_kernel_ptr(K, force_shape_idx, c_fp32, cb);
|
||||
f_tr.shape_idx = force_shape_idx;
|
||||
f_tr.num_sms = force_num_sms;
|
||||
f_tr.block_dim = exl3_gemm_blockdim[force_shape_idx];
|
||||
return &f_tr;
|
||||
};
|
||||
TORCH_CHECK(!force_num_sms, "Must supply force_shape_idx and force_num_sms together.");
|
||||
|
||||
// Cache parameters
|
||||
uint64_t key = (((uint64_t) size_k) << 40) |
|
||||
(((uint64_t) size_n) << 16) |
|
||||
(((uint64_t) cc) << 8) |
|
||||
(((uint64_t) K) << 4) |
|
||||
(c_fp32 ? 0x01ull : 0x00ull);
|
||||
|
||||
auto lookup = _tuning_cache.find(key);
|
||||
if (lookup == _tuning_cache.end())
|
||||
{
|
||||
// Find closest kernel in map
|
||||
bool mod512 = (size_n % 512 == 0);
|
||||
bool mod256 = (size_n % 256 == 0);
|
||||
bool mod128 = (size_n % 128 == 0);
|
||||
TORCH_CHECK(mod128, "size_n must be a multiple of 128");
|
||||
TSample* cand = mod512 ? samples_512 : (mod256 ? samples_256 : samples_128);
|
||||
TSample* best = nullptr;
|
||||
int64_t best_dist = 1ll<<62;
|
||||
|
||||
for (; cand->K; cand++)
|
||||
{
|
||||
if (cand->K != K) continue;
|
||||
if (cand->cc != cc) continue;
|
||||
|
||||
int64_t distk = (int64_t) (size_k - cand->k);
|
||||
int64_t distn = (int64_t) (size_n - cand->n);
|
||||
int64_t dist = distk * distk + distn * distn;
|
||||
if (dist < best_dist) { best_dist = dist; best = cand; }
|
||||
}
|
||||
TORCH_CHECK(best, "Failed to find valid kernel for shape");
|
||||
|
||||
// Avoid empty blocks
|
||||
int tilesize_k = exl3_gemm_tilesize_k[best->shape_idx];
|
||||
int tilesize_n = exl3_gemm_tilesize_n[best->shape_idx];
|
||||
int max_slices = size_k / tilesize_k * size_n / tilesize_n;
|
||||
int num_sms = MAX(MIN(max_slices, best->num_sms), 1);
|
||||
|
||||
// Results
|
||||
TResult tr = {
|
||||
get_gemm_kernel_ptr(K, best->shape_idx, c_fp32, cb),
|
||||
get_mgemm_kernel_ptr(K, best->shape_idx, c_fp32, cb),
|
||||
best->shape_idx,
|
||||
num_sms,
|
||||
exl3_gemm_blockdim[best->shape_idx]
|
||||
};
|
||||
|
||||
_tuning_cache[key] = tr;
|
||||
}
|
||||
|
||||
lookup = _tuning_cache.find(key);
|
||||
return &(lookup->second);
|
||||
}
|
||||
@@ -129,7 +129,7 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
|
||||
int size_m,
|
||||
int size_k,
|
||||
int size_n,
|
||||
int bits,
|
||||
int K,
|
||||
bool c_fp32,
|
||||
int force_shape_idx,
|
||||
int* out_block_dim,
|
||||
@@ -138,4 +138,48 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
|
||||
int cb,
|
||||
int bszm_in,
|
||||
int bszm_out
|
||||
);
|
||||
);
|
||||
|
||||
struct TSample {
|
||||
int cc;
|
||||
int K;
|
||||
int m;
|
||||
int k;
|
||||
int n;
|
||||
int shape_idx;
|
||||
int num_sms;
|
||||
};
|
||||
|
||||
struct TMSample {
|
||||
int cc;
|
||||
int K;
|
||||
int m;
|
||||
int k;
|
||||
int n;
|
||||
int shape_idx;
|
||||
int num_sms;
|
||||
int bszm_in;
|
||||
int bszm_out;
|
||||
};
|
||||
|
||||
struct TResult
|
||||
{
|
||||
fp_exl3_gemm_kernel kernel;
|
||||
fp_exl3_mgemm_kernel mkernel;
|
||||
int shape_idx;
|
||||
int num_sms;
|
||||
int block_dim;
|
||||
};
|
||||
|
||||
TResult* select_exl3_gemm_mgemm_kernel_new
|
||||
(
|
||||
int cc,
|
||||
int size_m,
|
||||
int size_k,
|
||||
int size_n,
|
||||
int K,
|
||||
bool c_fp32,
|
||||
int force_shape_idx,
|
||||
int force_num_sms,
|
||||
int cb
|
||||
);
|
||||
|
||||
13118
exllamav3/exllamav3_ext/quant/exl3_kernel_map_samples.cuh
Normal file
13118
exllamav3/exllamav3_ext/quant/exl3_kernel_map_samples.cuh
Normal file
File diff suppressed because it is too large
Load Diff
@@ -379,7 +379,8 @@ class Attention(Module):
|
||||
self.multi_kv.mcg_mult,
|
||||
self.multi_kv.mul1_mult,
|
||||
-1,
|
||||
-1
|
||||
-1,
|
||||
0
|
||||
)
|
||||
k = kv[0].view(bsz, q_len, self.num_kv_heads * self.head_dim)
|
||||
v = kv[1].view(bsz, q_len, self.num_kv_heads * self.head_dim)
|
||||
|
||||
@@ -590,6 +590,7 @@ class BlockSparseMLP(Module):
|
||||
self.multi_gate.mul1_mult,
|
||||
mine,
|
||||
maxe,
|
||||
0
|
||||
)
|
||||
|
||||
# Up
|
||||
@@ -608,6 +609,7 @@ class BlockSparseMLP(Module):
|
||||
self.multi_up.mul1_mult,
|
||||
mine,
|
||||
maxe,
|
||||
0
|
||||
)
|
||||
|
||||
# Activation
|
||||
@@ -629,6 +631,7 @@ class BlockSparseMLP(Module):
|
||||
self.multi_down.mul1_mult,
|
||||
mine,
|
||||
maxe,
|
||||
0
|
||||
)
|
||||
|
||||
t = cfg.out_d[0]
|
||||
@@ -661,7 +664,8 @@ class BlockSparseMLP(Module):
|
||||
self.multi_gate.mcg_mult,
|
||||
self.multi_gate.mul1_mult,
|
||||
cfg.min_expert,
|
||||
cfg.max_expert
|
||||
cfg.max_expert,
|
||||
0
|
||||
)
|
||||
|
||||
# Up
|
||||
@@ -679,7 +683,8 @@ class BlockSparseMLP(Module):
|
||||
self.multi_up.mcg_mult,
|
||||
self.multi_up.mul1_mult,
|
||||
cfg.min_expert,
|
||||
cfg.max_expert
|
||||
cfg.max_expert,
|
||||
0
|
||||
)
|
||||
|
||||
# Activation
|
||||
@@ -700,7 +705,8 @@ class BlockSparseMLP(Module):
|
||||
self.multi_down.mcg_mult,
|
||||
self.multi_down.mul1_mult,
|
||||
cfg.min_expert,
|
||||
cfg.max_expert
|
||||
cfg.max_expert,
|
||||
0
|
||||
)
|
||||
|
||||
final_hidden_states = cfg.out_d[:1, ...].view(x.shape)
|
||||
|
||||
@@ -635,7 +635,8 @@ class GatedMLP(Module):
|
||||
self.multi_gu[s].mcg_mult,
|
||||
self.multi_gu[s].mul1_mult,
|
||||
-1,
|
||||
-1
|
||||
-1,
|
||||
0
|
||||
)
|
||||
g = gu[0].view(bsz, q_len, self.multi_gu[s].out_features)
|
||||
u = gu[1].view(bsz, q_len, self.multi_gu[s].out_features)
|
||||
|
||||
@@ -14,6 +14,11 @@ runs = 60
|
||||
shapes_m = [1, 4, 16]
|
||||
|
||||
shapes_kn = [
|
||||
(128, 4096),
|
||||
(4096, 128),
|
||||
(4096, 256),
|
||||
(4096, 512),
|
||||
(4096, 4096),
|
||||
(2048, 4096),
|
||||
(4096, 14336),
|
||||
(14336, 4096),
|
||||
@@ -77,7 +82,7 @@ def main():
|
||||
svh = [proto_svh.clone() for _ in range(num_buffers)]
|
||||
|
||||
# Get preferred kernel for current shape
|
||||
pref = ext.exl3_gemm(a[0], b[0], c[0], suh[0], a[0], svh[0], -1, mcg_mult, mul1_mult)
|
||||
pref = ext.exl3_gemm(a[0], b[0], c[0], suh[0], a[0], svh[0], -1, mcg_mult, mul1_mult, 0)
|
||||
|
||||
# Test all kernels
|
||||
kresults = []
|
||||
@@ -94,14 +99,14 @@ def main():
|
||||
# Warmup passes for good measure
|
||||
for i_ in range(10):
|
||||
i = i_ % num_buffers
|
||||
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], kernel, mcg_mult, mul1_mult)
|
||||
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], kernel, mcg_mult, mul1_mult, 0)
|
||||
|
||||
# Test
|
||||
dummy = c[0][0, 0].item()
|
||||
with Timer() as t:
|
||||
for i_ in range(runs):
|
||||
i = i_ % num_buffers
|
||||
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], kernel, mcg_mult, mul1_mult)
|
||||
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], kernel, mcg_mult, mul1_mult, 0)
|
||||
dummy = c[i][0, 0].item()
|
||||
|
||||
mean_time_ms = t.interval / runs * 1000
|
||||
|
||||
561
science/qgemm_pretune.py
Normal file
561
science/qgemm_pretune.py
Normal file
@@ -0,0 +1,561 @@
|
||||
import sys, os
|
||||
from collections import OrderedDict
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
import torch
|
||||
from exllamav3.ext import exllamav3_ext as ext
|
||||
from exllamav3.util import Timer
|
||||
from exllamav3.util.memory import free_mem
|
||||
from tabulate import tabulate
|
||||
import numpy as np
|
||||
|
||||
num_warmup_passes = 10
|
||||
num_benchmark_iter_a = 20
|
||||
num_benchmark_iter_b = 40
|
||||
outlier_trim = 0.3
|
||||
assume_cache = 384 * 1024 ** 2
|
||||
|
||||
devices = [
|
||||
"cuda:1",
|
||||
"cuda:2",
|
||||
"cuda:3",
|
||||
]
|
||||
|
||||
shapes_m = [1]
|
||||
|
||||
shapes_k = [
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
5120,
|
||||
8192,
|
||||
12288,
|
||||
14336,
|
||||
16384,
|
||||
24576,
|
||||
]
|
||||
|
||||
shapes_n = [
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
5120,
|
||||
8192,
|
||||
12288,
|
||||
14336,
|
||||
16384,
|
||||
24576,
|
||||
51200,
|
||||
128000,
|
||||
]
|
||||
|
||||
shape_indices_128 = [1, 2]
|
||||
shape_indices_256 = [3]
|
||||
shape_indices_512 = [4]
|
||||
|
||||
Ks = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
|
||||
mgemm_bszm_io = [
|
||||
(1, 8),
|
||||
(8, 1),
|
||||
]
|
||||
|
||||
g_spin = 0
|
||||
|
||||
def get_abc(K, m, k, n, device):
|
||||
proto_a = torch.randn((m, k), dtype = torch.half, device = device)
|
||||
proto_b = torch.zeros((k // 16, n // 16, 16 * K), dtype = torch.short, device = device)
|
||||
proto_c = torch.zeros((m, n), dtype = torch.half, device = device)
|
||||
proto_suh = torch.randn((k,), dtype = torch.half, device = device)
|
||||
proto_svh = torch.randn((n,), dtype = torch.half, device = device)
|
||||
|
||||
# Create enough clones to cycle through to prevent L2 cache from skewing results
|
||||
proto_size = proto_a.numel() * 2 + proto_b.numel() * 2 + proto_c.numel() * 2
|
||||
num_buffers = max(assume_cache // proto_size + 1, 2)
|
||||
a = [proto_a.clone() for _ in range(num_buffers)]
|
||||
b = [proto_b.clone() for _ in range(num_buffers)]
|
||||
c = [proto_c.clone() for _ in range(num_buffers)]
|
||||
suh = [proto_suh.clone() for _ in range(num_buffers)]
|
||||
svh = [proto_svh.clone() for _ in range(num_buffers)]
|
||||
return a, b, c, suh, svh
|
||||
|
||||
|
||||
def get_abc_m(K, m, k, n, device, bszm_in, bszm_out):
|
||||
bszm = max(bszm_in, bszm_out)
|
||||
proto_a = torch.randn((bszm_in, m, k), dtype = torch.half, device = device)
|
||||
proto_b = [torch.zeros((k // 16, n // 16, 16 * K), dtype = torch.short, device = device) for _ in range(bszm)]
|
||||
proto_c = torch.zeros((bszm_out, m, n), dtype = torch.half, device = device)
|
||||
proto_suh = [torch.randn((k,), dtype = torch.half, device = device) for _ in range(bszm)]
|
||||
proto_svh = [torch.randn((n,), dtype = torch.half, device = device) for _ in range(bszm)]
|
||||
|
||||
# Create enough clones to cycle through to prevent L2 cache from skewing results
|
||||
proto_size = proto_a.numel() * 2 + sum(p.numel() for p in proto_b) * 2 + proto_c.numel() * 2
|
||||
num_buffers = max(assume_cache // proto_size + 1, 2)
|
||||
a = [proto_a.clone() for _ in range(num_buffers)]
|
||||
b = [[proto_b_.clone() for proto_b_ in proto_b] for _ in range(num_buffers)]
|
||||
c = [proto_c.clone() for _ in range(num_buffers)]
|
||||
suh = [[proto_suh_.clone() for proto_suh_ in proto_suh] for _ in range(num_buffers)]
|
||||
svh = [[proto_svh_.clone() for proto_svh_ in proto_svh] for _ in range(num_buffers)]
|
||||
trellis = b
|
||||
ptrs_suh = [torch.tensor([suh__.data_ptr() for suh__ in suh_], dtype = torch.long, device = device) for suh_ in suh]
|
||||
ptrs_svh = [torch.tensor([svh__.data_ptr() for svh__ in svh_], dtype = torch.long, device = device) for svh_ in svh]
|
||||
ptrs_trellis = [torch.tensor([trellis__.data_ptr() for trellis__ in trellis_], dtype = torch.long, device = device) for trellis_ in trellis]
|
||||
return a, b, c, suh, svh, ptrs_suh, ptrs_svh, ptrs_trellis
|
||||
|
||||
|
||||
def warmup(a, b, c, suh, svh, shape_idx):
|
||||
global g_spin
|
||||
num_buffers = len(a)
|
||||
for _ in range(num_warmup_passes):
|
||||
i = g_spin % num_buffers
|
||||
g_spin += 1
|
||||
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], shape_idx, 0, 0, 64)
|
||||
|
||||
|
||||
def warmup_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis):
|
||||
num_buffers = len(a)
|
||||
num_exp = ptrs_trellis[0].shape[0]
|
||||
m_indices = torch.arange(0, num_exp, dtype = torch.long, device = a[0].device).unsqueeze(0)
|
||||
K = b[0][0].shape[-1] // 16
|
||||
for i_ in range(num_warmup_passes):
|
||||
i = i_ % num_buffers
|
||||
ext.exl3_mgemm(
|
||||
a[i],
|
||||
ptrs_trellis[i],
|
||||
c[i],
|
||||
ptrs_suh[i],
|
||||
a[i],
|
||||
ptrs_svh[i],
|
||||
m_indices,
|
||||
None,
|
||||
K,
|
||||
-1,
|
||||
0,
|
||||
0,
|
||||
-1,
|
||||
-1,
|
||||
0
|
||||
)
|
||||
|
||||
|
||||
def benchmark(a, b, c, suh, svh, shape_idx, num_iter, num_sms):
|
||||
num_buffers = len(a)
|
||||
dummy = c[0][0, 0].item()
|
||||
with Timer() as t:
|
||||
for i_ in range(num_iter):
|
||||
i = i_ % num_buffers
|
||||
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], shape_idx, 0, 0, num_sms)
|
||||
dummy = c[i][0, 0].item()
|
||||
mean_time_ms = t.interval / num_iter * 1000
|
||||
return mean_time_ms
|
||||
|
||||
|
||||
def benchmark_per_launch(a, b, c, suh, svh, shape_idx, num_iter, num_sms, trim = outlier_trim, stream = None):
|
||||
global g_spin
|
||||
device = a[0].device
|
||||
if stream is None:
|
||||
stream = torch.cuda.current_stream(device)
|
||||
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
# Precreate events to reduce overhead jitter
|
||||
starts = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
|
||||
stops = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
|
||||
|
||||
# Timed loop
|
||||
for it in range(num_iter):
|
||||
i = g_spin % len(a)
|
||||
g_spin += 1
|
||||
starts[it].record(stream)
|
||||
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], shape_idx, 0, 0, num_sms)
|
||||
stops[it].record(stream)
|
||||
|
||||
# Ensure all recorded events are complete before reading times
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
# Collect per-iteration latency in milliseconds (GPU time)
|
||||
per_ms = np.array([starts[it].elapsed_time(stops[it]) for it in range(num_iter)], dtype = np.float64)
|
||||
|
||||
# Robust stats
|
||||
median = float(np.median(per_ms))
|
||||
mean = float(per_ms.mean())
|
||||
std = float(per_ms.std(ddof=1)) if num_iter > 1 else 0.0
|
||||
|
||||
# Simple symmetric trimming
|
||||
if 0.0 < trim < 0.5 and num_iter > 4:
|
||||
lo = np.quantile(per_ms, trim)
|
||||
hi = np.quantile(per_ms, 1.0 - trim)
|
||||
trimmed = per_ms[(per_ms >= lo) & (per_ms <= hi)]
|
||||
trimmed_mean = float(trimmed.mean()) if trimmed.size else float('nan')
|
||||
else:
|
||||
trimmed = per_ms
|
||||
trimmed_mean = mean
|
||||
|
||||
return {
|
||||
"per_launch_ms": per_ms, # numpy array of length num_iter
|
||||
"mean_ms": mean,
|
||||
"median_ms": median,
|
||||
"std_ms": std,
|
||||
"trimmed_mean_ms": trimmed_mean,
|
||||
"trim_bounds_ms": (
|
||||
float(lo) if 'lo' in locals() else None,
|
||||
float(hi) if 'hi' in locals() else None
|
||||
),
|
||||
"kept_count": int(trimmed.size),
|
||||
"total_count": int(num_iter),
|
||||
}
|
||||
|
||||
|
||||
def benchmark_per_launch_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis, num_iter, num_sms, trim = outlier_trim, stream = None):
|
||||
device = a[0].device
|
||||
if stream is None:
|
||||
stream = torch.cuda.current_stream(device)
|
||||
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
num_exp = ptrs_trellis[0].shape[0]
|
||||
m_indices = torch.arange(0, num_exp, dtype = torch.long, device = a[0].device).unsqueeze(0)
|
||||
K = b[0][0].shape[-1] // 16
|
||||
|
||||
# Precreate events to reduce overhead jitter
|
||||
starts = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
|
||||
stops = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
|
||||
|
||||
# Timed loop
|
||||
for it in range(num_iter):
|
||||
i = it % len(a)
|
||||
starts[it].record(stream)
|
||||
for _ in range(2):
|
||||
ext.exl3_mgemm(
|
||||
a[i],
|
||||
ptrs_trellis[i],
|
||||
c[i],
|
||||
ptrs_suh[i],
|
||||
a[i],
|
||||
ptrs_svh[i],
|
||||
m_indices,
|
||||
None,
|
||||
K,
|
||||
shape_idx,
|
||||
0,
|
||||
0,
|
||||
-1,
|
||||
-1,
|
||||
num_sms
|
||||
)
|
||||
stops[it].record(stream)
|
||||
|
||||
# Ensure all recorded events are complete before reading times
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
# Collect per-iteration latency in milliseconds (GPU time)
|
||||
per_ms = np.array([starts[it].elapsed_time(stops[it]) / 2 for it in range(num_iter)], dtype = np.float64)
|
||||
|
||||
# Robust stats
|
||||
median = float(np.median(per_ms))
|
||||
mean = float(per_ms.mean())
|
||||
std = float(per_ms.std(ddof=1)) if num_iter > 1 else 0.0
|
||||
|
||||
# Simple symmetric trimming
|
||||
if 0.0 < trim < 0.5 and num_iter > 4:
|
||||
lo = np.quantile(per_ms, trim)
|
||||
hi = np.quantile(per_ms, 1.0 - trim)
|
||||
trimmed = per_ms[(per_ms >= lo) & (per_ms <= hi)]
|
||||
trimmed_mean = float(trimmed.mean()) if trimmed.size else float('nan')
|
||||
else:
|
||||
trimmed = per_ms
|
||||
trimmed_mean = mean
|
||||
|
||||
return {
|
||||
"per_launch_ms": per_ms, # numpy array of length num_iter
|
||||
"mean_ms": mean,
|
||||
"median_ms": median,
|
||||
"std_ms": std,
|
||||
"trimmed_mean_ms": trimmed_mean,
|
||||
"trim_bounds_ms": (
|
||||
float(lo) if 'lo' in locals() else None,
|
||||
float(hi) if 'hi' in locals() else None
|
||||
),
|
||||
"kept_count": int(trimmed.size),
|
||||
"total_count": int(num_iter),
|
||||
}
|
||||
|
||||
def get_floor(hist_sms):
|
||||
best = float("inf"), 0, 0
|
||||
for l, s, i in hist_sms:
|
||||
if l < best[0]:
|
||||
best = l, s, i
|
||||
for l, s, i in hist_sms:
|
||||
if l < best[0] * 1.008:
|
||||
return l, s, i
|
||||
return best
|
||||
|
||||
|
||||
def test_latencies(mod, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices):
|
||||
|
||||
g_hist = []
|
||||
|
||||
for shape_idx in shape_indices:
|
||||
|
||||
max_sms = ext.g_get_num_sms(device_idx)
|
||||
|
||||
warmup(a, b, c, suh, svh, shape_idx)
|
||||
|
||||
hist_sms = []
|
||||
for num_sms in range(0, max_sms + 1, 8):
|
||||
# mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_a, max(num_sms, 1))
|
||||
num_sms_ = max(num_sms, 1)
|
||||
stats = benchmark_per_launch(a, b, c, suh, svh, shape_idx, num_benchmark_iter_a, num_sms_)
|
||||
mean = stats["trimmed_mean_ms"]
|
||||
hist_sms.append((mean, num_sms_, shape_idx))
|
||||
|
||||
_, best_sms, _ = get_floor(hist_sms)
|
||||
|
||||
sms_a = best_sms - 8
|
||||
sms_b = best_sms + 8
|
||||
if sms_a < 0:
|
||||
sms_a = 0
|
||||
sms_b = 16
|
||||
if sms_b > num_sms + 4:
|
||||
sms_b = num_sms + 4
|
||||
sms_a = sms_b - 16
|
||||
|
||||
warmup(a, b, c, suh, svh, shape_idx)
|
||||
|
||||
for num_sms in range(sms_a, sms_b + 1, 2):
|
||||
# mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_b, min(max(num_sms, 1), max_sms))
|
||||
num_sms_ = min(max(num_sms, 1), max_sms)
|
||||
stats = benchmark_per_launch(a, b, c, suh, svh, shape_idx, num_benchmark_iter_b, num_sms_)
|
||||
mean = stats["trimmed_mean_ms"]
|
||||
g_hist.append((mean, num_sms_, shape_idx))
|
||||
|
||||
g_hist = sorted(g_hist, key = lambda t: t[1])
|
||||
best_g_lat, best_g_sms, best_g_idx = get_floor(g_hist)
|
||||
|
||||
print(f"mod {mod}, dev {device_idx}, m {m:6}, k {k:6}, n {n:6}, K {K}, mean {best_g_lat:8.5f} ms, num_sms {best_g_sms:3}, shape_idx {best_g_idx}")
|
||||
return best_g_lat, best_g_sms, best_g_idx
|
||||
|
||||
|
||||
def test_latencies_m(mod, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices, bszm_in, bszm_out):
|
||||
|
||||
best_g_lat = float("inf")
|
||||
best_g_sms = None
|
||||
best_g_idx = None
|
||||
|
||||
best_l_lat = float("inf")
|
||||
best_l_sms = None
|
||||
best_l_idx = None
|
||||
|
||||
for shape_idx in shape_indices:
|
||||
|
||||
# Skip incompatible shapes
|
||||
# if not ext.exl3_gemm_shape_compat(shape_idx, m, k, n, K):
|
||||
# continue
|
||||
|
||||
max_sms = ext.g_get_num_sms(device_idx)
|
||||
|
||||
warmup_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis)
|
||||
|
||||
best_sms_lat = float("inf")
|
||||
for num_sms in range(0, max_sms + 1, 8):
|
||||
# mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_a, max(num_sms, 1))
|
||||
stats = benchmark_per_launch_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis, num_benchmark_iter_a, max(num_sms, 1))
|
||||
mean = stats["trimmed_mean_ms"]
|
||||
if mean < best_sms_lat:
|
||||
best_sms_lat = mean
|
||||
best_sms = max(num_sms, 1)
|
||||
|
||||
sms_a = best_sms - 8
|
||||
sms_b = best_sms + 8
|
||||
if sms_a < 0:
|
||||
sms_a = 0
|
||||
sms_b = 16
|
||||
if sms_b > num_sms + 4:
|
||||
sms_b = num_sms + 4
|
||||
sms_a = sms_b - 16
|
||||
|
||||
warmup_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis)
|
||||
|
||||
for num_sms in range(sms_a, sms_b + 1, 2):
|
||||
# mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_b, min(max(num_sms, 1), max_sms))
|
||||
stats = benchmark_per_launch_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis, num_benchmark_iter_b, min(max(num_sms, 1), max_sms))
|
||||
mean = stats["trimmed_mean_ms"]
|
||||
if mean < best_g_lat:
|
||||
best_g_lat = mean
|
||||
best_g_sms = min(max(num_sms, 1), max_sms)
|
||||
best_g_idx = shape_idx
|
||||
if mean < best_l_lat:
|
||||
best_l_lat = mean
|
||||
best_l_sms = min(max(num_sms, 1), max_sms)
|
||||
best_l_idx = shape_idx
|
||||
|
||||
print(f"mod {mod}, dev {device_idx}, m {m:6}, k {k:6}, n {n:6}, K {K}, mean {best_l_lat:8.5f} ms, num_sms {best_l_sms:3}, shape_idx {best_l_idx}, i/o {bszm_in}/{bszm_out}")
|
||||
|
||||
print()
|
||||
print(f" --> mod {mod}, dev {device_idx}, m {m:6}, k {k:6}, n {n:6}, K {K}, mean {best_g_lat:8.5f} ms, num_sms {best_g_sms:3}, shape_idx {best_g_idx}, i/o {bszm_in}/{bszm_out}")
|
||||
print()
|
||||
return best_g_lat, best_g_sms, best_g_idx
|
||||
|
||||
|
||||
def tune_shape(K, m, k, n, device):
|
||||
|
||||
a, b, c, suh, svh = get_abc(K, m, k, n, device)
|
||||
device_idx = torch.device(device).index
|
||||
cc = ext.g_get_cc(device_idx)
|
||||
|
||||
res = {
|
||||
"K": K,
|
||||
"m": m,
|
||||
"k": k,
|
||||
"n": n,
|
||||
"cc": cc,
|
||||
"bszm_in": 1,
|
||||
"bszm_out": 1,
|
||||
}
|
||||
res_128 = None
|
||||
res_256 = None
|
||||
res_512 = None
|
||||
|
||||
if True:
|
||||
lat_128, sms_128, idx_128 = test_latencies(128, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices_128)
|
||||
res_128 = res.copy()
|
||||
res_128.update({"lat": lat_128, "sms": sms_128, "idx": idx_128})
|
||||
|
||||
if n % 256 == 0:
|
||||
lat_256, sms_256, idx_256 = test_latencies(256, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices_256)
|
||||
if lat_128 < lat_256:
|
||||
lat_256, sms_256, idx_256 = lat_128, sms_128, idx_128
|
||||
res_256 = res.copy()
|
||||
res_256.update({"lat": lat_256, "sms": sms_256, "idx": idx_256})
|
||||
|
||||
if n % 512 == 0:
|
||||
lat_512, sms_512, idx_512 = test_latencies(512, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices_512)
|
||||
if lat_128 < lat_512:
|
||||
lat_512, sms_512, idx_512 = lat_128, sms_128, idx_128
|
||||
if lat_256 < lat_512:
|
||||
lat_512, sms_512, idx_512 = lat_256, sms_256, idx_256
|
||||
res_512 = res.copy()
|
||||
res_512.update({"lat": lat_512, "sms": sms_512, "idx": idx_512})
|
||||
|
||||
return res_128, res_256, res_512
|
||||
|
||||
|
||||
def tune_shape_m(K, m, k, n, device, bszm_in, bszm_out):
|
||||
|
||||
a, b, c, suh, svh, ptrs_suh, ptrs_svh, ptrs_trellis = get_abc_m(K, m, k, n, device, bszm_in, bszm_out)
|
||||
device_idx = torch.device(device).index
|
||||
cc = ext.g_get_cc(device_idx)
|
||||
|
||||
res = {
|
||||
"K": K,
|
||||
"m": m,
|
||||
"k": k,
|
||||
"n": n,
|
||||
"cc": cc,
|
||||
"bszm_in": 1,
|
||||
"bszm_out": 1,
|
||||
}
|
||||
res_128 = None
|
||||
res_256 = None
|
||||
res_512 = None
|
||||
|
||||
if True:
|
||||
lat_128, sms_128, idx_128 = test_latencies_m(128, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices_128, bszm_in, bszm_out)
|
||||
res_128 = res.copy()
|
||||
res_128.update({"lat": lat_128, "sms": sms_128, "idx": idx_128})
|
||||
|
||||
if n % 256 == 0:
|
||||
lat_256, sms_256, idx_256 = test_latencies_m(256, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices_256, bszm_in, bszm_out)
|
||||
if lat_128 < lat_256:
|
||||
lat_256, sms_256, idx_256 = lat_128, sms_128, idx_128
|
||||
res_256 = res.copy()
|
||||
res_256.update({"lat": lat_256, "sms": sms_256, "idx": idx_256})
|
||||
|
||||
if n % 512 == 0:
|
||||
lat_512, sms_512, idx_512 = test_latencies_m(512, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices_512, bszm_in, bszm_out)
|
||||
if lat_128 < lat_512:
|
||||
lat_512, sms_512, idx_512 = lat_128, sms_128, idx_128
|
||||
if lat_256 < lat_512:
|
||||
lat_512, sms_512, idx_512 = lat_256, sms_256, idx_256
|
||||
res_512 = res.copy()
|
||||
res_512.update({"lat": lat_512, "sms": sms_512, "idx": idx_512})
|
||||
|
||||
return res_128, res_256, res_512
|
||||
|
||||
|
||||
def tune_gemm():
|
||||
out_128 = "struct TSample samples_128[] =\n{\n"
|
||||
out_256 = "struct TSample samples_256[] =\n{\n"
|
||||
out_512 = "struct TSample samples_512[] =\n{\n"
|
||||
for device in devices:
|
||||
for m in shapes_m:
|
||||
for k in shapes_k:
|
||||
for n in shapes_n:
|
||||
for ki, K in enumerate(Ks):
|
||||
res_128, res_256, res_512 = tune_shape(K, m, k, n, device)
|
||||
r = res_128
|
||||
if r:
|
||||
out_128 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']} }},\n"
|
||||
r = res_256
|
||||
if r:
|
||||
out_256 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']} }},\n"
|
||||
r = res_512
|
||||
if r:
|
||||
out_512 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']} }},\n"
|
||||
out_128 = out_128 + " { 0, 0, 0, 0, 0, 0, 0 }\n};"
|
||||
out_256 = out_256 + " { 0, 0, 0, 0, 0, 0, 0 }\n};"
|
||||
out_512 = out_512 + " { 0, 0, 0, 0, 0, 0, 0 }\n};"
|
||||
print(out_128)
|
||||
print()
|
||||
print(out_256)
|
||||
print()
|
||||
print(out_512)
|
||||
print()
|
||||
|
||||
|
||||
def tune_mgemm():
|
||||
out_128 = "struct TMSample msamples_128[] =\n{\n"
|
||||
out_256 = "struct TMSample msamples_256[] =\n{\n"
|
||||
out_512 = "struct TMSample msamples_512[] =\n{\n"
|
||||
for device in devices:
|
||||
for m in shapes_m:
|
||||
for k in shapes_k:
|
||||
for n in shapes_n:
|
||||
for ki, K in enumerate(Ks):
|
||||
for (bszm_in, bszm_out) in mgemm_bszm_io:
|
||||
res_128, res_256, res_512 = tune_shape_m(K, m, k, n, device, bszm_in, bszm_out)
|
||||
r = res_128
|
||||
if r:
|
||||
out_128 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']}, {r['bszm_in']}, {r['bszm_out']} }},\n"
|
||||
r = res_256
|
||||
if r:
|
||||
out_256 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']}, {r['bszm_in']}, {r['bszm_out']} }},\n"
|
||||
r = res_512
|
||||
if r:
|
||||
out_512 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']}, {r['bszm_in']}, {r['bszm_out']} }},\n"
|
||||
out_128 = out_128 + " { 0, 0, 0, 0, 0, 0, 0, 0, 0 }\n};"
|
||||
out_256 = out_256 + " { 0, 0, 0, 0, 0, 0, 0, 0, 0 }\n};"
|
||||
out_512 = out_512 + " { 0, 0, 0, 0, 0, 0, 0, 0, 0 }\n};"
|
||||
print(out_128)
|
||||
print()
|
||||
print(out_256)
|
||||
print()
|
||||
print(out_512)
|
||||
print()
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main():
|
||||
tune_gemm()
|
||||
# tune_mgemm()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user