Rework GEMM kernel tuning

This commit is contained in:
turboderp
2025-10-05 01:28:06 +02:00
parent c3cae873c4
commit 4829ea43d9
16 changed files with 14037 additions and 77 deletions

View File

@@ -23,6 +23,7 @@
#include "quant/exl3_gemm.cuh"
#include "quant/exl3_kernel_map.cuh"
#include "quant/util.cuh"
#include "quant/exl3_devctx.cuh"
#include "generator/strings.h"
#include "generator/sampling_basic.cuh"
@@ -87,6 +88,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("exl3_gemm", &exl3_gemm, "exl3_gemm");
m.def("exl3_gemm_num_kernel_shapes", &exl3_gemm_num_kernel_shapes, "exl3_gemm_num_kernel_shapes");
m.def("exl3_gemm_shape_compat", &exl3_gemm_shape_compat, "exl3_gemm_shape_compat");
m.def("g_get_cc", &g_get_cc, "g_get_cc");
m.def("g_get_num_sms", &g_get_num_sms, "g_get_num_sms");
m.def("exl3_mgemm", &exl3_mgemm, "exl3_mgemm");
m.def("hgemm", &hgemm, "hgemm");
m.def("rope", &rope, "rope");

View File

@@ -85,7 +85,8 @@ void BC_BlockSparseMLP::run_bsz1
gate_mcg_mult,
gate_mul1_mult,
min_expert,
max_expert
max_expert,
0
);
exl3_mgemm(
@@ -102,7 +103,8 @@ void BC_BlockSparseMLP::run_bsz1
up_mcg_mult,
up_mul1_mult,
min_expert,
max_expert
max_expert,
0
);
if (act_silu)
@@ -124,7 +126,8 @@ void BC_BlockSparseMLP::run_bsz1
down_mcg_mult,
down_mul1_mult,
min_expert,
max_expert
max_expert,
0
);
if (shared_experts)

View File

@@ -31,12 +31,12 @@ void BC_LinearEXL3::run(const at::Tensor& x, at::Tensor& y)
{
if (x.numel() == x.size(-1))
{
exl3_gemm(x, trellis, y, suh, xh, svh, -1, mcg_mult, mul1_mult);
exl3_gemm(x, trellis, y, suh, xh, svh, -1, mcg_mult, mul1_mult, 0);
}
else
{
at::Tensor xh_ = at::empty_like(x);
exl3_gemm(x, trellis, y, suh, xh_, svh, -1, mcg_mult, mul1_mult);
exl3_gemm(x, trellis, y, suh, xh_, svh, -1, mcg_mult, mul1_mult, 0);
}
if (bias) y.add_(bias.value());

View File

@@ -31,7 +31,8 @@ void BC_GatedMLP::run_bsz1
gu_mcg_mult,
gu_mul1_mult,
-1,
-1
-1,
0
);
at::Tensor g = gu.select(0, 0).unsqueeze(0);

View File

@@ -55,4 +55,14 @@ int* DevCtx::get_locks(int device)
cudaMemset(locks[device], 0, MAX_TILES_C * sizeof(int));
}
return (int*) locks[device];
}
// Module-level wrapper exposing the device's compute-capability class
// (see the CC_* defines in exl3_devctx.cuh) for Python bindings.
int g_get_cc(int device)
{
    return DevCtx::instance().get_cc(device);
}
// Module-level wrapper exposing the device's SM count (used by callers to
// size kernel grids / sweep num_sms during tuning) for Python bindings.
int g_get_num_sms(int device)
{
    return DevCtx::instance().get_num_sms(device);
}

View File

@@ -6,12 +6,13 @@
// Max allowable output size, in tiles. Used to allocate global lock buffer per device for sync across threadblocks
#define MAX_TILES_C (1024 * 1024)
// Treat hopper and blackwell as same arch for now
#define MAX_DEVICES 32
#define CC_OLD 1
#define CC_AMPERE 2
#define CC_ADA 3
#define CC_HOPPER 4
#define CC_BLACKWELL 5
#define CC_BLACKWELL 4
// Singleton to manage context for each device. Stores device attributes and a large-enough lock buffer per device
class DevCtx
@@ -32,4 +33,7 @@ private:
DevCtx() = default;
DevCtx(const DevCtx&) = delete;
DevCtx& operator=(const DevCtx&) = delete;
};
};
int g_get_cc(int device);
int g_get_num_sms(int device);

View File

@@ -12,11 +12,17 @@ namespace cg = cooperative_groups;
#include "exl3_devctx.cuh"
#include <set>
#define NEW_TUNE_GEMM
#define NEW_TUNE_MGEMM
int exl3_gemm_tilesize_k_g[] = {EXL3_GEMM_TILESIZE_K};
int exl3_gemm_tilesize_n_g[] = {EXL3_GEMM_TILESIZE_N};
/*
EXL3 matmul, A @ B -> C
- A: row-major A tensor, shape (m, k), dtype float16, contiguous
- B: EXL3-quantized B tensor, shape (k//16, n//16, 16*bits), dtype uint16
- B: EXL3-quantized B tensor, shape (k//16, n//16, 16*K), dtype uint16
- C: empty row-major C tensor, shape (m, n), dtype float16 or float32, contiguous. Does not need to be zero-initialized
- suh: optional, packed input scales/flips, shape (k//16), dtype float16
- A_had: required if suh given, may be reference to A, temporary storage for input transform, size and dtype as A
@@ -39,7 +45,8 @@ int exl3_gemm
const c10::optional<at::Tensor>& svh,
int force_shape_idx,
uint32_t mcg_mult,
uint32_t mul1_mult
uint32_t mul1_mult,
int force_num_sms
)
{
const at::cuda::OptionalCUDAGuard device_guard(A.device());
@@ -48,7 +55,7 @@ int exl3_gemm
TORCH_CHECK_DIM(B, 3);
TORCH_CHECK_SHAPES(A, -1, B, 0, 16);
TORCH_CHECK_SHAPES(C, -1, B, 1, 16);
// TORCH_CHECK_SHAPES(A, 0, C, 0, 1);
// TORCH_CHECK_SHAPES(A, 0, C, 0, 1);
TORCH_CHECK_DTYPE(A, kHalf);
TORCH_CHECK_DTYPE(B, kShort);
bool c_fp32 = C.dtype() == at::kFloat;
@@ -59,26 +66,26 @@ int exl3_gemm
half* A_had_ptr = nullptr;
if (suh_ptr)
{
// TORCH_CHECK_SHAPES(suh.value(), 0, A, 1, 1);
// TORCH_CHECK_SHAPES(suh.value(), 0, A, 1, 1);
A_had_ptr = (half*) OPTPTR(A_had);
// TORCH_CHECK(A_had_ptr, "Must supply A_had with suh");
// TORCH_CHECK_SHAPES_FULL(A_had.value(), A);
// TORCH_CHECK(A_had_ptr, "Must supply A_had with suh");
// TORCH_CHECK_SHAPES_FULL(A_had.value(), A);
}
// Get SV, optionally
const half* svh_ptr = (const half*) OPTPTR(svh);
// if (svh_ptr)
// TORCH_CHECK_SHAPES(svh.value(), 0, B, 1, 16);
// if (svh_ptr)
// TORCH_CHECK_SHAPES(svh.value(), 0, B, 1, 16);
// Device properties
int device;
cudaGetDevice(&device);
int num_sms = DevCtx::instance().get_num_sms(device);
int num_sms = force_num_sms ? force_num_sms : DevCtx::instance().get_num_sms(device);
int cc = DevCtx::instance().get_cc(device);
int* locks = DevCtx::instance().get_locks(device);
// Dispatch
int bits = B.size(2) / 16;
int K = B.size(2) / 16;
const half* A_ptr = (const half*) A.data_ptr();
const uint16_t* B_ptr = (const uint16_t*) B.data_ptr();
void* C_ptr = (void*) C.data_ptr();
@@ -96,21 +103,33 @@ int exl3_gemm
if (mcg_mult) { cb = 1; mult = mcg_mult; }
if (mul1_mult) { cb = 2; mult = mul1_mult; }
int selected_shape;
int block_dim;
fp_exl3_gemm_kernel kernel = select_exl3_gemm_kernel
(
cc, size_m, size_k, size_n, bits, c_fp32,
force_shape_idx, &block_dim, &selected_shape,
&num_sms, cb
);
if (!kernel) return 0;
int shape_idx;
fp_exl3_gemm_kernel kernel;
#ifndef NEW_TUNE_GEMM
kernel = select_exl3_gemm_kernel
(
cc, size_m, size_k, size_n, K, c_fp32,
force_shape_idx, &block_dim, &shape_idx,
&num_sms, cb
);
if (!kernel) return 0;
#else
TResult* tr = select_exl3_gemm_mgemm_kernel_new(cc, size_m, size_k, size_n, K, c_fp32, force_shape_idx, force_num_sms, cb);
if (!tr) return 0;
num_sms = MIN(num_sms, tr->num_sms);
kernel = tr->kernel;
block_dim = tr->block_dim;
shape_idx = tr->shape_idx;
#endif
// Launch
if (kernel_attr_set[device].find((void*)kernel) == kernel_attr_set[device].end())
if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end())
{
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
kernel_attr_set[device].insert((void*)kernel);
kernel_attr_set[device].insert((void*) kernel);
cuda_check(cudaPeekAtLastError());
}
void* kernelArgs[] =
{
@@ -128,7 +147,7 @@ int exl3_gemm
};
cudaLaunchCooperativeKernel
(
(void*)kernel,
(void*) kernel,
num_sms,
block_dim,
kernelArgs,
@@ -136,14 +155,16 @@ int exl3_gemm
stream
);
cuda_check(cudaPeekAtLastError());
return selected_shape;
// return selected_shape;
return shape_idx;
}
/*
EXL3 multi matmul, A @ B -> C
- A: row-major A tensor, shape (m, k), dtype float16, contiguous
- B: EXL3-quantized B tensor, shape (k//16, n//16, 16*bits), dtype uint16
- B: EXL3-quantized B tensor, shape (k//16, n//16, 16*K), dtype uint16
- C: empty row-major C tensor, shape (m, n), dtype float16 or float32, contiguous. Does not need to be zero-initialized
- suh: optional, packed input scales/flips, shape (k//16), dtype float16
- A_had: required if suh given, may be reference to A, temporary storage for input transform, size and dtype as A
@@ -169,7 +190,8 @@ int exl3_mgemm
uint32_t mcg_mult,
uint32_t mul1_mult,
int min_index,
int max_index
int max_index,
int force_num_sms
)
{
const at::cuda::OptionalCUDAGuard device_guard(A.device());
@@ -194,6 +216,7 @@ int exl3_mgemm
int bsz = A.size(1);
int bszm_in = A.size(0);
int bszm_out = C.size(0);
int bszm = MAX(bszm_in, bszm_out);
const long* indices_ptr = (const long*) OPTPTR(indices);
const half* weights_ptr = (const half*) OPTPTR(weights);
@@ -219,8 +242,8 @@ int exl3_mgemm
// Device properties
int device;
cudaGetDevice(&device);
int num_sms = DevCtx::instance().get_num_sms(device);
int total_sms = num_sms;
int total_sms = DevCtx::instance().get_num_sms(device);
int num_sms = force_num_sms ? force_num_sms : total_sms;
int cc = DevCtx::instance().get_cc(device);
int* locks = DevCtx::instance().get_locks(device);
@@ -239,25 +262,44 @@ int exl3_mgemm
if (mcg_mult) { cb = 1; mult = mcg_mult; }
if (mul1_mult) { cb = 2; mult = mul1_mult; }
int selected_shape;
int shape_idx;
int block_dim;
fp_exl3_mgemm_kernel kernel = select_exl3_mgemm_kernel
(
cc, size_m, size_k, size_n, K, c_fp32,
force_shape_idx, &block_dim, &selected_shape,
&num_sms, cb, bszm_in, bszm_out
);
if (!kernel) return 0;
fp_exl3_mgemm_kernel kernel;
int concurrency;
#ifndef NEW_TUNE_MGEMM
kernel = select_exl3_mgemm_kernel
(
cc, size_m, size_k, size_n, K, c_fp32,
force_shape_idx, &block_dim, &shape_idx,
&num_sms, cb, bszm_in, bszm_out
);
if (!kernel) return 0;
concurrency = MIN(total_sms / num_sms, bszm_out);
#else
kernel = select_exl3_mgemm_kernel
(
cc, size_m, size_k, size_n, K, c_fp32,
force_shape_idx, &block_dim, &shape_idx,
&num_sms, cb, bszm_in, bszm_out
);
int tilesize_k = exl3_gemm_tilesize_k_g[shape_idx];
int tilesize_n = exl3_gemm_tilesize_n_g[shape_idx];
int tiles = MAX(size_k / tilesize_k * size_n / tilesize_n, 1);
num_sms = tiles;
if (num_sms * bszm > total_sms) num_sms = MAX(total_sms / bszm, 1);
if (num_sms <= total_sms && tiles / num_sms > 48) num_sms = MIN(total_sms, num_sms * 2);
concurrency = MIN(total_sms / num_sms, bszm);
#endif
// Launch bigger grid if possible
int concurrency = MIN(total_sms / num_sms, bszm_out);
dim3 block_grid(num_sms, 1, concurrency);
// Launch
if (kernel_attr_set[device].find((void*)kernel) == kernel_attr_set[device].end())
if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end())
{
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
kernel_attr_set[device].insert((void*)kernel);
kernel_attr_set[device].insert((void*) kernel);
}
void* kernelArgs[] =
{
@@ -279,10 +321,9 @@ int exl3_mgemm
(void*)& min_index,
(void*)& max_index
};
cudaLaunchCooperativeKernel
(
(void*)kernel,
(void*) kernel,
block_grid,
block_dim,
kernelArgs,
@@ -290,5 +331,5 @@ int exl3_mgemm
stream
);
cuda_check(cudaPeekAtLastError());
return selected_shape;
return shape_idx;
}

View File

@@ -12,7 +12,8 @@ int exl3_gemm
const c10::optional<at::Tensor>& svh,
int force_shape_idx,
uint32_t mcg_mult,
uint32_t mul1_mult
uint32_t mul1_mult,
int force_num_sms
);
int exl3_mgemm
@@ -30,5 +31,6 @@ int exl3_mgemm
uint32_t mcg_mult,
uint32_t mul1_mult,
int min_index,
int max_index
int max_index,
int force_num_sms
);

View File

@@ -8,6 +8,7 @@ namespace cg = cooperative_groups;
#include "../ptx.cuh"
#include <tuple>
#include <mutex>
#include <map>
#include "exl3_kernel_map.cuh"
#include "exl3_devctx.cuh"
#include "comp_units/exl3_comp_unit_1.cuh"
@@ -19,7 +20,10 @@ namespace cg = cooperative_groups;
#include "comp_units/exl3_comp_unit_7.cuh"
#include "comp_units/exl3_comp_unit_8.cuh"
int select_gemm_shape(int cc, int size_m, int size_k, int size_n, int bits, bool multi, int bszm_in, int bszm_out)
#include "exl3_kernel_map_samples.cuh"
std::map<uint64_t, TResult> _tuning_cache = {};
int select_gemm_shape(int cc, int size_m, int size_k, int size_n, int K, bool multi, int bszm_in, int bszm_out)
{
bool mod_256 = (size_n % 256 == 0);
bool mod_512 = (size_n % 512 == 0);
@@ -31,18 +35,18 @@ int select_gemm_shape(int cc, int size_m, int size_k, int size_n, int bits, bool
{
case CC_OLD:
case CC_AMPERE:
if (mod_256 && bits <= 4)
if (mod_256 && K <= 4)
{
if (size_n <= 2048 || size_k <= 2048) return 2;
return 3;
}
if (mod_256 && size_n < 4096) return size_k > 8192 ? 3 : 2;
if (mod_512 && (size_n * size_k) > (4096 * 4096) && bits <= 6) return 4;
if (mod_512 && (size_n * size_k) > (4096 * 4096) && K <= 6) return 4;
if (mod_256) return 3;
return 2;
case CC_ADA:
if (mod_256 && bits <= 3)
if (mod_256 && K <= 3)
{
if (size_k <= 2048 && !multi) return 2;
if (size_n < 4096 && size_k <= 12288) return 2;
@@ -53,19 +57,19 @@ int select_gemm_shape(int cc, int size_m, int size_k, int size_n, int bits, bool
if (mod_256) return 3;
return 2;
case CC_HOPPER:
// case CC_HOPPER:
case CC_BLACKWELL:
if ((bits == 4 || bits == 2) && !multi)
if ((K == 4 || K == 2) && !multi)
{
if (size_k <= 2048) return 1;
}
if (bits >= 7)
if (K >= 7)
{
if (mod_256 && size_n <= 8192) return size_k > 32768 ? 3 : 2;
if (mod_512 && size_n > 32768) return 4;
return 2;
}
if (mod_256 && size_n <= 4096) return size_k > 8192 && bits >= 3 ? 3 : 2;
if (mod_256 && size_n <= 4096) return size_k > 8192 && K >= 3 ? 3 : 2;
if (mod_512 && size_n > 16384) return 4;
if (mod_256) return 3;
return 2;
@@ -82,7 +86,7 @@ int exl3_gemm_tilesize_k[] = {EXL3_GEMM_TILESIZE_K};
int exl3_gemm_tilesize_n[] = {EXL3_GEMM_TILESIZE_N};
int exl3_gemm_blockdim[] = {EXL3_GEMM_BLOCKDIM};
bool exl3_gemm_shape_compat(int shape_idx, int size_m, int size_k, int size_n, int bits)
bool exl3_gemm_shape_compat(int shape_idx, int size_m, int size_k, int size_n, int K)
{
int tilesize_k = exl3_gemm_tilesize_k[shape_idx];
int tilesize_n = exl3_gemm_tilesize_n[shape_idx];
@@ -95,7 +99,7 @@ fp_exl3_gemm_kernel select_exl3_gemm_kernel
int size_m,
int size_k,
int size_n,
int bits,
int K,
bool c_fp32,
int force_shape_idx,
int* out_block_dim,
@@ -104,7 +108,7 @@ fp_exl3_gemm_kernel select_exl3_gemm_kernel
int cb
)
{
int shape_idx = force_shape_idx <= 0 ? select_gemm_shape(cc, size_m, size_k, size_n, bits, false, 1, 1) : force_shape_idx;
int shape_idx = force_shape_idx <= 0 ? select_gemm_shape(cc, size_m, size_k, size_n, K, false, 1, 1) : force_shape_idx;
TORCH_CHECK(shape_idx > 0, "exl3_gemm: no compatible kernel");
if (out_shape_idx) *out_shape_idx = shape_idx;
@@ -123,7 +127,7 @@ fp_exl3_gemm_kernel select_exl3_gemm_kernel
if (c_fp32)
{
switch (bits)
switch (K)
{
case 1: return tfp_exl3_gemm_kernel_fp32_b1[kernel_idx];
case 2: return tfp_exl3_gemm_kernel_fp32_b2[kernel_idx];
@@ -138,7 +142,7 @@ fp_exl3_gemm_kernel select_exl3_gemm_kernel
}
else
{
switch (bits)
switch (K)
{
case 1: return tfp_exl3_gemm_kernel_fp16_b1[kernel_idx];
case 2: return tfp_exl3_gemm_kernel_fp16_b2[kernel_idx];
@@ -159,7 +163,7 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
int size_m,
int size_k,
int size_n,
int bits,
int K,
bool c_fp32,
int force_shape_idx,
int* out_block_dim,
@@ -170,7 +174,7 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
int bszm_out
)
{
int shape_idx = force_shape_idx <= 0 ? select_gemm_shape(cc, size_m, size_k, size_n, bits, true, bszm_in, bszm_out) : force_shape_idx;
int shape_idx = force_shape_idx <= 0 ? select_gemm_shape(cc, size_m, size_k, size_n, K, true, bszm_in, bszm_out) : force_shape_idx;
TORCH_CHECK(shape_idx > 0, "exl3_mgemm: no compatible kernel");
if (out_shape_idx) *out_shape_idx = shape_idx;
if (out_block_dim) *out_block_dim = exl3_gemm_blockdim[shape_idx];
@@ -188,7 +192,7 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
if (c_fp32)
{
switch (bits)
switch (K)
{
case 1: return tfp_exl3_mgemm_kernel_fp32_b1[kernel_idx];
case 2: return tfp_exl3_mgemm_kernel_fp32_b2[kernel_idx];
@@ -203,7 +207,7 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
}
else
{
switch (bits)
switch (K)
{
case 1: return tfp_exl3_mgemm_kernel_fp16_b1[kernel_idx];
case 2: return tfp_exl3_mgemm_kernel_fp16_b2[kernel_idx];
@@ -216,4 +220,160 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
default: TORCH_CHECK(false, "No kernel for GEMM shape");
}
}
}
// Look up a single-GEMM kernel by bitrate (K), tile shape and codebook variant.
// The per-K kernel tables are flattened as index = shape_idx + (num_shapes + 1) * cb.
fp_exl3_gemm_kernel get_gemm_kernel_ptr(int K, int shape_idx, bool c_fp32, int cb)
{
    int kernel_idx = shape_idx + (EXL3_GEMM_NUM_SHAPES + 1) * cb;
    switch (K)
    {
        case 1: return c_fp32 ? tfp_exl3_gemm_kernel_fp32_b1[kernel_idx] : tfp_exl3_gemm_kernel_fp16_b1[kernel_idx];
        case 2: return c_fp32 ? tfp_exl3_gemm_kernel_fp32_b2[kernel_idx] : tfp_exl3_gemm_kernel_fp16_b2[kernel_idx];
        case 3: return c_fp32 ? tfp_exl3_gemm_kernel_fp32_b3[kernel_idx] : tfp_exl3_gemm_kernel_fp16_b3[kernel_idx];
        case 4: return c_fp32 ? tfp_exl3_gemm_kernel_fp32_b4[kernel_idx] : tfp_exl3_gemm_kernel_fp16_b4[kernel_idx];
        case 5: return c_fp32 ? tfp_exl3_gemm_kernel_fp32_b5[kernel_idx] : tfp_exl3_gemm_kernel_fp16_b5[kernel_idx];
        case 6: return c_fp32 ? tfp_exl3_gemm_kernel_fp32_b6[kernel_idx] : tfp_exl3_gemm_kernel_fp16_b6[kernel_idx];
        case 7: return c_fp32 ? tfp_exl3_gemm_kernel_fp32_b7[kernel_idx] : tfp_exl3_gemm_kernel_fp16_b7[kernel_idx];
        case 8: return c_fp32 ? tfp_exl3_gemm_kernel_fp32_b8[kernel_idx] : tfp_exl3_gemm_kernel_fp16_b8[kernel_idx];
        default: TORCH_CHECK(false, "No kernel for GEMM shape");
    }
}
// Look up a multi-GEMM (mgemm) kernel by bitrate (K), tile shape and codebook
// variant. Same flattened table layout as get_gemm_kernel_ptr.
fp_exl3_mgemm_kernel get_mgemm_kernel_ptr(int K, int shape_idx, bool c_fp32, int cb)
{
    int kernel_idx = shape_idx + (EXL3_GEMM_NUM_SHAPES + 1) * cb;
    switch (K)
    {
        case 1: return c_fp32 ? tfp_exl3_mgemm_kernel_fp32_b1[kernel_idx] : tfp_exl3_mgemm_kernel_fp16_b1[kernel_idx];
        case 2: return c_fp32 ? tfp_exl3_mgemm_kernel_fp32_b2[kernel_idx] : tfp_exl3_mgemm_kernel_fp16_b2[kernel_idx];
        case 3: return c_fp32 ? tfp_exl3_mgemm_kernel_fp32_b3[kernel_idx] : tfp_exl3_mgemm_kernel_fp16_b3[kernel_idx];
        case 4: return c_fp32 ? tfp_exl3_mgemm_kernel_fp32_b4[kernel_idx] : tfp_exl3_mgemm_kernel_fp16_b4[kernel_idx];
        case 5: return c_fp32 ? tfp_exl3_mgemm_kernel_fp32_b5[kernel_idx] : tfp_exl3_mgemm_kernel_fp16_b5[kernel_idx];
        case 6: return c_fp32 ? tfp_exl3_mgemm_kernel_fp32_b6[kernel_idx] : tfp_exl3_mgemm_kernel_fp16_b6[kernel_idx];
        case 7: return c_fp32 ? tfp_exl3_mgemm_kernel_fp32_b7[kernel_idx] : tfp_exl3_mgemm_kernel_fp16_b7[kernel_idx];
        case 8: return c_fp32 ? tfp_exl3_mgemm_kernel_fp32_b8[kernel_idx] : tfp_exl3_mgemm_kernel_fp16_b8[kernel_idx];
        default: TORCH_CHECK(false, "No kernel for GEMM shape");
    }
}
// Scratch result used when shape/num_sms are forced explicitly (benchmarking path).
// NOTE(review): f_tr and _tuning_cache are shared mutable state with no lock taken
// here — confirm callers serialize access (the file includes <mutex> but no guard).
TResult f_tr;

/*
Select GEMM/MGEMM launch parameters for a problem shape using the pretuned
sample tables (exl3_kernel_map_samples.cuh), caching the resolved result per
(size_k, size_n, cc, K, c_fp32) key.

- force_shape_idx > 0 bypasses tuning; force_num_sms must be supplied with it
- returns a pointer into the cache (std::map node addresses are stable), never
  nullptr — failure paths throw via TORCH_CHECK
- NOTE(review): size_m participates in neither the cache key nor the sample
  distance below — confirm tuning is intentionally m-independent
*/
TResult* select_exl3_gemm_mgemm_kernel_new
(
    int cc,
    int size_m,
    int size_k,
    int size_n,
    int K,
    bool c_fp32,
    int force_shape_idx,
    int force_num_sms,
    int cb
)
{
    // Force parameters for tuning/benchmarking
    if (force_shape_idx > 0)
    {
        TORCH_CHECK(force_num_sms, "Must supply force_shape_idx and force_num_sms together");
        f_tr.kernel = get_gemm_kernel_ptr(K, force_shape_idx, c_fp32, cb);
        f_tr.mkernel = get_mgemm_kernel_ptr(K, force_shape_idx, c_fp32, cb);
        f_tr.shape_idx = force_shape_idx;
        f_tr.num_sms = force_num_sms;
        f_tr.block_dim = exl3_gemm_blockdim[force_shape_idx];
        return &f_tr;
    }
    TORCH_CHECK(!force_num_sms, "Must supply force_shape_idx and force_num_sms together.");

    // Cache key layout: size_k in bits 40+, size_n in bits 16..39, cc in 8..15,
    // K in 4..7, c_fp32 flag in bit 0
    uint64_t key = (((uint64_t) size_k) << 40) |
                   (((uint64_t) size_n) << 16) |
                   (((uint64_t) cc) << 8) |
                   (((uint64_t) K) << 4) |
                   (c_fp32 ? 0x01ull : 0x00ull);

    auto lookup = _tuning_cache.find(key);
    if (lookup == _tuning_cache.end())
    {
        // Find closest sample: pick the table matching the finest alignment of
        // size_n, then minimize squared (k, n) distance among same-cc/K samples
        bool mod512 = (size_n % 512 == 0);
        bool mod256 = (size_n % 256 == 0);
        bool mod128 = (size_n % 128 == 0);
        TORCH_CHECK(mod128, "size_n must be a multiple of 128");
        TSample* cand = mod512 ? samples_512 : (mod256 ? samples_256 : samples_128);
        TSample* best = nullptr;
        int64_t best_dist = 1ll << 62;
        for (; cand->K; cand++)  // tables are terminated by a K == 0 sentinel
        {
            if (cand->K != K) continue;
            if (cand->cc != cc) continue;
            int64_t distk = (int64_t) (size_k - cand->k);
            int64_t distn = (int64_t) (size_n - cand->n);
            int64_t dist = distk * distk + distn * distn;
            if (dist < best_dist) { best_dist = dist; best = cand; }
        }
        TORCH_CHECK(best, "Failed to find valid kernel for shape");

        // Avoid empty blocks: never launch more SMs than there are output tiles
        int tilesize_k = exl3_gemm_tilesize_k[best->shape_idx];
        int tilesize_n = exl3_gemm_tilesize_n[best->shape_idx];
        int max_slices = size_k / tilesize_k * size_n / tilesize_n;
        int num_sms = MAX(MIN(max_slices, best->num_sms), 1);

        // Insert once and reuse the returned iterator (the original re-ran
        // _tuning_cache.find(key) after insertion — a redundant second lookup)
        TResult tr = {
            get_gemm_kernel_ptr(K, best->shape_idx, c_fp32, cb),
            get_mgemm_kernel_ptr(K, best->shape_idx, c_fp32, cb),
            best->shape_idx,
            num_sms,
            exl3_gemm_blockdim[best->shape_idx]
        };
        lookup = _tuning_cache.emplace(key, tr).first;
    }
    return &(lookup->second);
}

View File

@@ -129,7 +129,7 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
int size_m,
int size_k,
int size_n,
int bits,
int K,
bool c_fp32,
int force_shape_idx,
int* out_block_dim,
@@ -138,4 +138,48 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
int cb,
int bszm_in,
int bszm_out
);
);
// One pretuned GEMM sample point: the best-performing (shape_idx, num_sms)
// measured for a problem of size (m, k, n) at bitrate K on architecture class
// cc. Sample tables are terminated by a K == 0 sentinel entry.
struct TSample {
    int cc;         // compute-capability class (CC_* define)
    int K;          // weight bitrate; trellis B is packed as (k/16, n/16, 16*K) uint16
    int m;          // problem size the sample was measured at (rows of A)
    int k;          // problem size the sample was measured at (inner dim)
    int n;          // problem size the sample was measured at (cols of C)
    int shape_idx;  // index into the EXL3_GEMM_TILESIZE_* / BLOCKDIM tables
    int num_sms;    // grid size (SM count) that performed best
};
// Pretuned sample point for the multi-GEMM (mgemm) path: same fields as
// TSample plus the batch-of-matrices dimensions the sample was measured with.
struct TMSample {
    int cc;         // compute-capability class (CC_* define)
    int K;          // weight bitrate; trellis B is packed as (k/16, n/16, 16*K) uint16
    int m;          // measured problem size (rows of A)
    int k;          // measured problem size (inner dim)
    int n;          // measured problem size (cols of C)
    int shape_idx;  // index into the EXL3_GEMM_TILESIZE_* / BLOCKDIM tables
    int num_sms;    // grid size (SM count) that performed best
    int bszm_in;    // number of input matrices (A.size(0) in exl3_mgemm)
    int bszm_out;   // number of output matrices (C.size(0) in exl3_mgemm)
};
// Resolved kernel selection returned by select_exl3_gemm_mgemm_kernel_new,
// cached per problem shape.
struct TResult
{
    fp_exl3_gemm_kernel kernel;    // single-matrix GEMM kernel entry point
    fp_exl3_mgemm_kernel mkernel;  // multi-matrix (mgemm) variant for the same shape
    int shape_idx;                 // selected tile-shape index
    int num_sms;                   // grid size to launch with
    int block_dim;                 // threads per block for the selected shape
};
// New sample-table-driven kernel selector. Pass force_shape_idx > 0 together
// with force_num_sms to bypass tuning (benchmark path); otherwise both must be
// unset and the result is chosen from pretuned samples and cached.
TResult* select_exl3_gemm_mgemm_kernel_new
(
    int cc,
    int size_m,
    int size_k,
    int size_n,
    int K,
    bool c_fp32,
    int force_shape_idx,
    int force_num_sms,
    int cb
);

File diff suppressed because it is too large Load Diff

View File

@@ -379,7 +379,8 @@ class Attention(Module):
self.multi_kv.mcg_mult,
self.multi_kv.mul1_mult,
-1,
-1
-1,
0
)
k = kv[0].view(bsz, q_len, self.num_kv_heads * self.head_dim)
v = kv[1].view(bsz, q_len, self.num_kv_heads * self.head_dim)

View File

@@ -590,6 +590,7 @@ class BlockSparseMLP(Module):
self.multi_gate.mul1_mult,
mine,
maxe,
0
)
# Up
@@ -608,6 +609,7 @@ class BlockSparseMLP(Module):
self.multi_up.mul1_mult,
mine,
maxe,
0
)
# Activation
@@ -629,6 +631,7 @@ class BlockSparseMLP(Module):
self.multi_down.mul1_mult,
mine,
maxe,
0
)
t = cfg.out_d[0]
@@ -661,7 +664,8 @@ class BlockSparseMLP(Module):
self.multi_gate.mcg_mult,
self.multi_gate.mul1_mult,
cfg.min_expert,
cfg.max_expert
cfg.max_expert,
0
)
# Up
@@ -679,7 +683,8 @@ class BlockSparseMLP(Module):
self.multi_up.mcg_mult,
self.multi_up.mul1_mult,
cfg.min_expert,
cfg.max_expert
cfg.max_expert,
0
)
# Activation
@@ -700,7 +705,8 @@ class BlockSparseMLP(Module):
self.multi_down.mcg_mult,
self.multi_down.mul1_mult,
cfg.min_expert,
cfg.max_expert
cfg.max_expert,
0
)
final_hidden_states = cfg.out_d[:1, ...].view(x.shape)

View File

@@ -635,7 +635,8 @@ class GatedMLP(Module):
self.multi_gu[s].mcg_mult,
self.multi_gu[s].mul1_mult,
-1,
-1
-1,
0
)
g = gu[0].view(bsz, q_len, self.multi_gu[s].out_features)
u = gu[1].view(bsz, q_len, self.multi_gu[s].out_features)

View File

@@ -14,6 +14,11 @@ runs = 60
shapes_m = [1, 4, 16]
shapes_kn = [
(128, 4096),
(4096, 128),
(4096, 256),
(4096, 512),
(4096, 4096),
(2048, 4096),
(4096, 14336),
(14336, 4096),
@@ -77,7 +82,7 @@ def main():
svh = [proto_svh.clone() for _ in range(num_buffers)]
# Get preferred kernel for current shape
pref = ext.exl3_gemm(a[0], b[0], c[0], suh[0], a[0], svh[0], -1, mcg_mult, mul1_mult)
pref = ext.exl3_gemm(a[0], b[0], c[0], suh[0], a[0], svh[0], -1, mcg_mult, mul1_mult, 0)
# Test all kernels
kresults = []
@@ -94,14 +99,14 @@ def main():
# Warmup passes for good measure
for i_ in range(10):
i = i_ % num_buffers
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], kernel, mcg_mult, mul1_mult)
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], kernel, mcg_mult, mul1_mult, 0)
# Test
dummy = c[0][0, 0].item()
with Timer() as t:
for i_ in range(runs):
i = i_ % num_buffers
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], kernel, mcg_mult, mul1_mult)
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], kernel, mcg_mult, mul1_mult, 0)
dummy = c[i][0, 0].item()
mean_time_ms = t.interval / runs * 1000

561
science/qgemm_pretune.py Normal file
View File

@@ -0,0 +1,561 @@
import sys, os
from collections import OrderedDict
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from exllamav3.ext import exllamav3_ext as ext
from exllamav3.util import Timer
from exllamav3.util.memory import free_mem
from tabulate import tabulate
import numpy as np
# ---- Pretune configuration ----
num_warmup_passes = 10          # untimed launches before each timed sweep
num_benchmark_iter_a = 20       # timed iterations for the coarse num_sms scan
num_benchmark_iter_b = 40       # timed iterations for the fine num_sms scan
outlier_trim = 0.3              # symmetric quantile trim applied to per-launch timings
assume_cache = 384 * 1024 ** 2  # assumed cache footprint (bytes) used to size buffer rotation

# Devices to pretune on. NOTE(review): cuda:0 is excluded — confirm intentional.
devices = [
    "cuda:1",
    "cuda:2",
    "cuda:3",
]

# Problem sizes swept: rows of A (m), inner dims (k), output dims (n)
shapes_m = [1]
shapes_k = [
    128,
    256,
    512,
    1024,
    2048,
    3072,
    4096,
    5120,
    8192,
    12288,
    14336,
    16384,
    24576,
]
shapes_n = [
    128,
    256,
    512,
    1024,
    2048,
    3072,
    4096,
    5120,
    8192,
    12288,
    14336,
    16384,
    24576,
    51200,
    128000,
]

# Candidate tile-shape indices, grouped by the alignment of size_n
shape_indices_128 = [1, 2]
shape_indices_256 = [3]
shape_indices_512 = [4]

# Bitrates to tune (trellis last dim is 16*K)
Ks = [1, 2, 3, 4, 5, 6, 7, 8]

# (bszm_in, bszm_out) combinations to tune for the multi-GEMM path
mgemm_bszm_io = [
    (1, 8),
    (8, 1),
]

g_spin = 0  # global cursor rotating through cloned benchmark buffers
def get_abc(K, m, k, n, device):
    """Allocate benchmark tensors for a single (m, k, n) GEMM at bitrate K.

    Returns lists of cloned buffers (a, b, c, suh, svh) so the benchmark can
    rotate through them, keeping the cache from skewing timings.
    """
    base_a = torch.randn((m, k), dtype = torch.half, device = device)
    base_b = torch.zeros((k // 16, n // 16, 16 * K), dtype = torch.short, device = device)
    base_c = torch.zeros((m, n), dtype = torch.half, device = device)
    base_suh = torch.randn((k,), dtype = torch.half, device = device)
    base_svh = torch.randn((n,), dtype = torch.half, device = device)
    # Enough copies that cycling through them exceeds the assumed cache footprint
    footprint = base_a.numel() * 2 + base_b.numel() * 2 + base_c.numel() * 2
    num_buffers = max(assume_cache // footprint + 1, 2)
    a = [base_a.clone() for _ in range(num_buffers)]
    b = [base_b.clone() for _ in range(num_buffers)]
    c = [base_c.clone() for _ in range(num_buffers)]
    suh = [base_suh.clone() for _ in range(num_buffers)]
    svh = [base_svh.clone() for _ in range(num_buffers)]
    return a, b, c, suh, svh
def get_abc_m(K, m, k, n, device, bszm_in, bszm_out):
    """Allocate benchmark tensors for the multi-GEMM (mgemm) path.

    Builds bszm = max(bszm_in, bszm_out) per-expert trellis/scale sets plus
    pointer tensors (lists of raw data_ptr values) that exl3_mgemm consumes,
    each cloned enough times to defeat cache reuse between timed iterations.
    """
    bszm = max(bszm_in, bszm_out)
    proto_a = torch.randn((bszm_in, m, k), dtype = torch.half, device = device)
    proto_b = [torch.zeros((k // 16, n // 16, 16 * K), dtype = torch.short, device = device) for _ in range(bszm)]
    proto_c = torch.zeros((bszm_out, m, n), dtype = torch.half, device = device)
    proto_suh = [torch.randn((k,), dtype = torch.half, device = device) for _ in range(bszm)]
    proto_svh = [torch.randn((n,), dtype = torch.half, device = device) for _ in range(bszm)]
    # Create enough clones to cycle through to prevent L2 cache from skewing results
    proto_size = proto_a.numel() * 2 + sum(p.numel() for p in proto_b) * 2 + proto_c.numel() * 2
    num_buffers = max(assume_cache // proto_size + 1, 2)
    a = [proto_a.clone() for _ in range(num_buffers)]
    b = [[proto_b_.clone() for proto_b_ in proto_b] for _ in range(num_buffers)]
    c = [proto_c.clone() for _ in range(num_buffers)]
    suh = [[proto_suh_.clone() for proto_suh_ in proto_suh] for _ in range(num_buffers)]
    svh = [[proto_svh_.clone() for proto_svh_ in proto_svh] for _ in range(num_buffers)]
    trellis = b
    # One pointer tensor per buffer set, with one device address per expert
    ptrs_suh = [torch.tensor([suh__.data_ptr() for suh__ in suh_], dtype = torch.long, device = device) for suh_ in suh]
    ptrs_svh = [torch.tensor([svh__.data_ptr() for svh__ in svh_], dtype = torch.long, device = device) for svh_ in svh]
    ptrs_trellis = [torch.tensor([trellis__.data_ptr() for trellis__ in trellis_], dtype = torch.long, device = device) for trellis_ in trellis]
    return a, b, c, suh, svh, ptrs_suh, ptrs_svh, ptrs_trellis
def warmup(a, b, c, suh, svh, shape_idx):
    # Run untimed GEMM launches so clocks/caches settle before timing.
    # Advances the global buffer cursor (g_spin) so consecutive calls keep
    # rotating through the cloned buffers rather than reusing the same set.
    global g_spin
    num_buffers = len(a)
    for _ in range(num_warmup_passes):
        i = g_spin % num_buffers
        g_spin += 1
        # Trailing args: mcg_mult = 0, mul1_mult = 0, force_num_sms = 64.
        # NOTE(review): 64 looks like an arbitrary warmup grid size — confirm.
        ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], shape_idx, 0, 0, 64)
def warmup_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis):
    # Untimed mgemm passes before benchmarking the multi-GEMM path.
    # NOTE(review): the shape_idx parameter is accepted but not forwarded
    # (the call below passes -1, auto-select) — confirm this is intentional.
    num_buffers = len(a)
    num_exp = ptrs_trellis[0].shape[0]
    # Identity routing: expert i uses trellis/scale set i
    m_indices = torch.arange(0, num_exp, dtype = torch.long, device = a[0].device).unsqueeze(0)
    # Recover the bitrate from the packed trellis shape (last dim is 16*K)
    K = b[0][0].shape[-1] // 16
    for i_ in range(num_warmup_passes):
        i = i_ % num_buffers
        ext.exl3_mgemm(
            a[i],
            ptrs_trellis[i],
            c[i],
            ptrs_suh[i],
            a[i],            # A_had aliases A
            ptrs_svh[i],
            m_indices,
            None,            # no per-expert weights
            K,
            -1,              # force_shape_idx: auto-select
            0,               # mcg_mult
            0,               # mul1_mult
            -1,              # min_index
            -1,              # max_index
            0                # force_num_sms: use device default
        )
def benchmark(a, b, c, suh, svh, shape_idx, num_iter, num_sms):
    """Time num_iter GEMM launches and return the mean wall-clock latency in ms.

    Reads one output element per iteration to force synchronization with the
    GPU inside the timed region.
    """
    n_buf = len(a)
    dummy = c[0][0, 0].item()  # sync before starting the timer
    with Timer() as timer:
        for step in range(num_iter):
            j = step % n_buf
            ext.exl3_gemm(a[j], b[j], c[j], suh[j], a[j], svh[j], shape_idx, 0, 0, num_sms)
            dummy = c[j][0, 0].item()
    return timer.interval / num_iter * 1000
def benchmark_per_launch(a, b, c, suh, svh, shape_idx, num_iter, num_sms, trim = outlier_trim, stream = None):
    """Time individual GEMM launches with CUDA events and return robust stats.

    Records a (start, stop) event pair around every launch, then reports mean,
    median, std and a symmetric-quantile trimmed mean, all in milliseconds of
    GPU time. Advances the global buffer cursor g_spin so repeated calls keep
    rotating through the cloned buffers.
    """
    global g_spin
    device = a[0].device
    if stream is None:
        stream = torch.cuda.current_stream(device)
    torch.cuda.synchronize(device)
    # Precreate events to reduce overhead jitter
    starts = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
    stops = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
    # Timed loop
    for it in range(num_iter):
        i = g_spin % len(a)
        g_spin += 1
        starts[it].record(stream)
        ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], shape_idx, 0, 0, num_sms)
        stops[it].record(stream)
    # Ensure all recorded events are complete before reading times
    torch.cuda.synchronize(device)
    # Collect per-iteration latency in milliseconds (GPU time)
    per_ms = np.array([starts[it].elapsed_time(stops[it]) for it in range(num_iter)], dtype = np.float64)
    # Robust stats
    median = float(np.median(per_ms))
    mean = float(per_ms.mean())
    std = float(per_ms.std(ddof=1)) if num_iter > 1 else 0.0
    # Simple symmetric trimming
    if 0.0 < trim < 0.5 and num_iter > 4:
        lo = np.quantile(per_ms, trim)
        hi = np.quantile(per_ms, 1.0 - trim)
        trimmed = per_ms[(per_ms >= lo) & (per_ms <= hi)]
        trimmed_mean = float(trimmed.mean()) if trimmed.size else float('nan')
    else:
        trimmed = per_ms
        trimmed_mean = mean
    return {
        "per_launch_ms": per_ms, # numpy array of length num_iter
        "mean_ms": mean,
        "median_ms": median,
        "std_ms": std,
        "trimmed_mean_ms": trimmed_mean,
        # lo/hi exist only when trimming ran; the locals() check keeps the
        # untrimmed path from raising NameError
        "trim_bounds_ms": (
            float(lo) if 'lo' in locals() else None,
            float(hi) if 'hi' in locals() else None
        ),
        "kept_count": int(trimmed.size),
        "total_count": int(num_iter),
    }
def benchmark_per_launch_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis, num_iter, num_sms, trim = outlier_trim, stream = None):
    """Per-launch CUDA-event timing for the multi-GEMM (mgemm) path.

    Each timed iteration launches the mgemm twice and the measured interval is
    halved, reducing per-launch jitter. Returns the same stats dict shape as
    benchmark_per_launch (milliseconds of GPU time).
    """
    device = a[0].device
    if stream is None:
        stream = torch.cuda.current_stream(device)
    torch.cuda.synchronize(device)
    num_exp = ptrs_trellis[0].shape[0]
    # Identity routing: expert i uses trellis/scale set i
    m_indices = torch.arange(0, num_exp, dtype = torch.long, device = a[0].device).unsqueeze(0)
    # Recover the bitrate from the packed trellis shape (last dim is 16*K)
    K = b[0][0].shape[-1] // 16
    # Precreate events to reduce overhead jitter
    starts = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
    stops = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
    # Timed loop
    for it in range(num_iter):
        i = it % len(a)
        starts[it].record(stream)
        for _ in range(2):  # two launches per event pair; divided out below
            ext.exl3_mgemm(
                a[i],
                ptrs_trellis[i],
                c[i],
                ptrs_suh[i],
                a[i],            # A_had aliases A
                ptrs_svh[i],
                m_indices,
                None,            # no per-expert weights
                K,
                shape_idx,
                0,               # mcg_mult
                0,               # mul1_mult
                -1,              # min_index
                -1,              # max_index
                num_sms
            )
        stops[it].record(stream)
    # Ensure all recorded events are complete before reading times
    torch.cuda.synchronize(device)
    # Collect per-iteration latency in milliseconds (GPU time); /2 undoes the
    # doubled launch above
    per_ms = np.array([starts[it].elapsed_time(stops[it]) / 2 for it in range(num_iter)], dtype = np.float64)
    # Robust stats
    median = float(np.median(per_ms))
    mean = float(per_ms.mean())
    std = float(per_ms.std(ddof=1)) if num_iter > 1 else 0.0
    # Simple symmetric trimming
    if 0.0 < trim < 0.5 and num_iter > 4:
        lo = np.quantile(per_ms, trim)
        hi = np.quantile(per_ms, 1.0 - trim)
        trimmed = per_ms[(per_ms >= lo) & (per_ms <= hi)]
        trimmed_mean = float(trimmed.mean()) if trimmed.size else float('nan')
    else:
        trimmed = per_ms
        trimmed_mean = mean
    return {
        "per_launch_ms": per_ms, # numpy array of length num_iter
        "mean_ms": mean,
        "median_ms": median,
        "std_ms": std,
        "trimmed_mean_ms": trimmed_mean,
        # lo/hi exist only when trimming ran; the locals() check keeps the
        # untrimmed path from raising NameError
        "trim_bounds_ms": (
            float(lo) if 'lo' in locals() else None,
            float(hi) if 'hi' in locals() else None
        ),
        "kept_count": int(trimmed.size),
        "total_count": int(num_iter),
    }
def get_floor(hist_sms):
    """Pick the latency floor from a list of (latency, num_sms, shape_idx).

    Finds the strict minimum latency, then returns the FIRST entry within 0.8%
    of it — so among near-ties the earliest-listed configuration wins. Returns
    (inf, 0, 0) for an empty history.
    """
    floor = (float("inf"), 0, 0)
    for entry in hist_sms:
        if entry[0] < floor[0]:
            floor = entry
    cutoff = floor[0] * 1.008
    for entry in hist_sms:
        if entry[0] < cutoff:
            return entry
    return floor
def test_latencies(mod, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices):
    # Two-phase search for the fastest (num_sms, shape_idx) combination for one
    # problem shape: a coarse sweep over SM counts in steps of 8, then a fine
    # sweep in steps of 2 around the coarse winner. Returns the chosen
    # (latency_ms, num_sms, shape_idx).
    g_hist = []
    for shape_idx in shape_indices:
        max_sms = ext.g_get_num_sms(device_idx)
        warmup(a, b, c, suh, svh, shape_idx)
        hist_sms = []
        # Coarse sweep; num_sms == 0 is clamped to 1 below.
        for num_sms in range(0, max_sms + 1, 8):
            # mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_a, max(num_sms, 1))
            num_sms_ = max(num_sms, 1)
            stats = benchmark_per_launch(a, b, c, suh, svh, shape_idx, num_benchmark_iter_a, num_sms_)
            mean = stats["trimmed_mean_ms"]
            hist_sms.append((mean, num_sms_, shape_idx))
        _, best_sms, _ = get_floor(hist_sms)
        # Fine sweep window of +/-8 SMs around the coarse winner, clamped.
        sms_a = best_sms - 8
        sms_b = best_sms + 8
        if sms_a < 0:
            sms_a = 0
            sms_b = 16
        # NOTE(review): num_sms here is the leftover coarse-loop variable (the
        # last step <= max_sms), not max_sms itself — confirm that was intended.
        if sms_b > num_sms + 4:
            sms_b = num_sms + 4
            sms_a = sms_b - 16
        warmup(a, b, c, suh, svh, shape_idx)
        for num_sms in range(sms_a, sms_b + 1, 2):
            # mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_b, min(max(num_sms, 1), max_sms))
            num_sms_ = min(max(num_sms, 1), max_sms)
            stats = benchmark_per_launch(a, b, c, suh, svh, shape_idx, num_benchmark_iter_b, num_sms_)
            mean = stats["trimmed_mean_ms"]
            g_hist.append((mean, num_sms_, shape_idx))
    # Sort by SM count so get_floor's tolerance pass prefers the lowest SM
    # count among near-equal latencies.
    g_hist = sorted(g_hist, key = lambda t: t[1])
    best_g_lat, best_g_sms, best_g_idx = get_floor(g_hist)
    print(f"mod {mod}, dev {device_idx}, m {m:6}, k {k:6}, n {n:6}, K {K}, mean {best_g_lat:8.5f} ms, num_sms {best_g_sms:3}, shape_idx {best_g_idx}")
    return best_g_lat, best_g_sms, best_g_idx
def test_latencies_m(mod, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices, bszm_in, bszm_out):
    # Multi-expert (mgemm) variant of test_latencies: coarse SM sweep in steps
    # of 8, then a fine sweep in steps of 2 around the coarse winner, tracking
    # the global best across all candidate shapes. Returns the global best
    # (latency_ms, num_sms, shape_idx).
    best_g_lat = float("inf")
    best_g_sms = None
    best_g_idx = None
    # NOTE(review): best_l_* are initialized once outside the shape loop, so
    # the per-shape print below actually reports the running best across all
    # shapes tried so far — confirm whether a per-shape reset was intended.
    best_l_lat = float("inf")
    best_l_sms = None
    best_l_idx = None
    for shape_idx in shape_indices:
        # Skip incompatible shapes
        # if not ext.exl3_gemm_shape_compat(shape_idx, m, k, n, K):
        #     continue
        max_sms = ext.g_get_num_sms(device_idx)
        warmup_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis)
        best_sms_lat = float("inf")
        # Coarse sweep; num_sms == 0 is clamped to 1.
        for num_sms in range(0, max_sms + 1, 8):
            # mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_a, max(num_sms, 1))
            stats = benchmark_per_launch_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis, num_benchmark_iter_a, max(num_sms, 1))
            mean = stats["trimmed_mean_ms"]
            if mean < best_sms_lat:
                best_sms_lat = mean
                best_sms = max(num_sms, 1)
        # Fine sweep window of +/-8 SMs around the coarse winner, clamped.
        sms_a = best_sms - 8
        sms_b = best_sms + 8
        if sms_a < 0:
            sms_a = 0
            sms_b = 16
        # NOTE(review): num_sms here is the leftover coarse-loop variable (the
        # last step <= max_sms), not max_sms itself — confirm that was intended.
        if sms_b > num_sms + 4:
            sms_b = num_sms + 4
            sms_a = sms_b - 16
        warmup_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis)
        for num_sms in range(sms_a, sms_b + 1, 2):
            # mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_b, min(max(num_sms, 1), max_sms))
            stats = benchmark_per_launch_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis, num_benchmark_iter_b, min(max(num_sms, 1), max_sms))
            mean = stats["trimmed_mean_ms"]
            if mean < best_g_lat:
                best_g_lat = mean
                best_g_sms = min(max(num_sms, 1), max_sms)
                best_g_idx = shape_idx
            if mean < best_l_lat:
                best_l_lat = mean
                best_l_sms = min(max(num_sms, 1), max_sms)
                best_l_idx = shape_idx
        print(f"mod {mod}, dev {device_idx}, m {m:6}, k {k:6}, n {n:6}, K {K}, mean {best_l_lat:8.5f} ms, num_sms {best_l_sms:3}, shape_idx {best_l_idx}, i/o {bszm_in}/{bszm_out}")
    print()
    print(f" --> mod {mod}, dev {device_idx}, m {m:6}, k {k:6}, n {n:6}, K {K}, mean {best_g_lat:8.5f} ms, num_sms {best_g_sms:3}, shape_idx {best_g_idx}, i/o {bszm_in}/{bszm_out}")
    print()
    return best_g_lat, best_g_sms, best_g_idx
def tune_shape(K, m, k, n, device):
    # Benchmark one (K, m, k, n) problem shape on one device and build result
    # records for the 128/256/512-wide kernel variants. A wider variant falls
    # back to a narrower variant's tuning when the narrower one measured
    # faster. Returns (res_128, res_256, res_512); an entry is None when n is
    # not a multiple of that variant's width.
    a, b, c, suh, svh = get_abc(K, m, k, n, device)
    device_idx = torch.device(device).index
    cc = ext.g_get_cc(device_idx)
    base = {
        "K": K,
        "m": m,
        "k": k,
        "n": n,
        "cc": cc,
        "bszm_in": 1,
        "bszm_out": 1,
    }

    def make_result(lat, sms, idx):
        # One TSample record: shared shape fields plus this variant's tuning.
        rec = base.copy()
        rec.update({"lat": lat, "sms": sms, "idx": idx})
        return rec

    res_256 = None
    res_512 = None

    best_128 = test_latencies(128, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices_128)
    res_128 = make_result(*best_128)

    if n % 256 == 0:
        best_256 = test_latencies(256, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices_256)
        if best_128[0] < best_256[0]:
            best_256 = best_128
        res_256 = make_result(*best_256)

    if n % 512 == 0:
        # n % 512 == 0 implies n % 256 == 0, so best_256 exists here.
        best_512 = test_latencies(512, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices_512)
        if best_128[0] < best_512[0]:
            best_512 = best_128
        if best_256[0] < best_512[0]:
            best_512 = best_256
        res_512 = make_result(*best_512)

    return res_128, res_256, res_512
def tune_shape_m(K, m, k, n, device, bszm_in, bszm_out):
    # Tune the multi-expert GEMM (mgemm) for one (K, m, k, n) shape at the
    # given expert batch sizes, producing TMSample records for the 128/256/512
    # kernel variants. A wider variant inherits a narrower variant's result
    # when the narrower one measured faster. Returns (res_128, res_256,
    # res_512); an entry is None when n is not a multiple of that width.
    a, b, c, suh, svh, ptrs_suh, ptrs_svh, ptrs_trellis = get_abc_m(K, m, k, n, device, bszm_in, bszm_out)
    device_idx = torch.device(device).index
    cc = ext.g_get_cc(device_idx)
    res = {
        "K": K,
        "m": m,
        "k": k,
        "n": n,
        "cc": cc,
        # Bug fix: record the batch sizes actually being tuned. These were
        # previously hardcoded to 1, so every emitted TMSample row claimed
        # i/o of 1/1 regardless of the (bszm_in, bszm_out) pair measured.
        "bszm_in": bszm_in,
        "bszm_out": bszm_out,
    }
    res_128 = None
    res_256 = None
    res_512 = None
    lat_128, sms_128, idx_128 = test_latencies_m(128, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices_128, bszm_in, bszm_out)
    res_128 = res.copy()
    res_128.update({"lat": lat_128, "sms": sms_128, "idx": idx_128})
    if n % 256 == 0:
        lat_256, sms_256, idx_256 = test_latencies_m(256, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices_256, bszm_in, bszm_out)
        # Fall back to the 128-wide tuning if it was faster.
        if lat_128 < lat_256:
            lat_256, sms_256, idx_256 = lat_128, sms_128, idx_128
        res_256 = res.copy()
        res_256.update({"lat": lat_256, "sms": sms_256, "idx": idx_256})
    if n % 512 == 0:
        # n % 512 == 0 implies n % 256 == 0, so lat_256 is defined here.
        lat_512, sms_512, idx_512 = test_latencies_m(512, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices_512, bszm_in, bszm_out)
        if lat_128 < lat_512:
            lat_512, sms_512, idx_512 = lat_128, sms_128, idx_128
        if lat_256 < lat_512:
            lat_512, sms_512, idx_512 = lat_256, sms_256, idx_256
        res_512 = res.copy()
        res_512.update({"lat": lat_512, "sms": sms_512, "idx": idx_512})
    return res_128, res_256, res_512
def tune_gemm():
    # Tune every configured (device, m, k, n, K) combination and print the
    # resulting TSample tables as C struct initializers, one per kernel width
    # (128/256/512), each terminated by an all-zero sentinel row.
    rows = {128: [], 256: [], 512: []}
    for device in devices:
        for m in shapes_m:
            for k in shapes_k:
                for n in shapes_n:
                    for ki, K in enumerate(Ks):
                        tuned = tune_shape(K, m, k, n, device)
                        for width, r in zip((128, 256, 512), tuned):
                            if r:
                                rows[width].append(
                                    f"    {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']} }},"
                                )
    for width in (128, 256, 512):
        table = [f"struct TSample samples_{width}[] =", "{"]
        table += rows[width]
        table.append("    { 0, 0, 0, 0, 0, 0, 0 }")
        table.append("};")
        print("\n".join(table))
        print()
def tune_mgemm():
    # Tune every configured (device, m, k, n, K, bszm_in/out) combination and
    # print the resulting TMSample tables as C struct initializers, one per
    # kernel width (128/256/512), each terminated by an all-zero sentinel row.
    rows = {128: [], 256: [], 512: []}
    for device in devices:
        for m in shapes_m:
            for k in shapes_k:
                for n in shapes_n:
                    for ki, K in enumerate(Ks):
                        for (bszm_in, bszm_out) in mgemm_bszm_io:
                            tuned = tune_shape_m(K, m, k, n, device, bszm_in, bszm_out)
                            for width, r in zip((128, 256, 512), tuned):
                                if r:
                                    rows[width].append(
                                        f"    {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']}, {r['bszm_in']}, {r['bszm_out']} }},"
                                    )
    for width in (128, 256, 512):
        table = [f"struct TMSample msamples_{width}[] =", "{"]
        table += rows[width]
        table.append("    { 0, 0, 0, 0, 0, 0, 0, 0, 0 }")
        table.append("};")
        print("\n".join(table))
        print()
@torch.inference_mode()
def main():
    # Entry point: run GEMM tuning. mgemm tuning is available but disabled
    # by default.
    tune_gemm()
    # tune_mgemm()
if __name__ == "__main__":
    main()