GEMM/MGEMM: Add autotuning with disk cache, remove static tuning table

turboderp
2026-04-29 15:55:54 +02:00
parent 64f020efd0
commit 1f220f6e50
11 changed files with 892 additions and 13908 deletions

View File

@@ -7,9 +7,13 @@
#include "util.cuh"
#include "quant/exl3_devctx.cuh"
//#define GRAPHDEBUG 1
Graph::Graph()
{
ready = false;
ready_to_record = false;
disabled = false;
graph = NULL;
graph_exec = NULL;
need_cublas = false;
@@ -23,6 +27,10 @@ Graph::~Graph()
cudaStream_t Graph::capture_begin()
{
#ifdef GRAPHDEBUG
printf("Begin graph capture\n");
#endif
// Make sure nothing is pending
cudaDeviceSynchronize();
@@ -88,6 +96,10 @@ void Graph::capture_end()
// Graph is ready
ready = true;
#ifdef GRAPHDEBUG
printf("End graph capture, num_nodes=%d, graph_sites.size()=%d\n", num_nodes, graph_sites.size());
#endif
}
void Graph::record_param(void* kernel, int param_id, int param_offset)

View File

@@ -73,6 +73,8 @@ public:
bool need_cublas;
bool ready;
bool ready_to_record;
bool disabled;
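// disabled forces the eager (non-graph) path at the call sites below;
// ready_to_record flags that one eager pass has completed, so the next
// invocation can capture the graph (presumably with all shapes autotuned).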
Graph();
~Graph();

View File

@@ -164,13 +164,13 @@ void BC_BlockSparseMLP::run_bsz1
c10::cuda::CUDAGuard device_guard(y.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
#define USE_GRAPH
#ifndef USE_GRAPH
if (graph_bsz1.disabled || (!graph_bsz1.ready && !graph_bsz1.ready_to_record))
{
run_bsz1_gr(y, selected_experts, routing_weights, nullptr);
#else
graph_bsz1.ready_to_record = true;
}
else
{
if (!graph_bsz1.ready)
{
graph_bsz1.capture_begin();
@@ -214,9 +214,7 @@ void BC_BlockSparseMLP::run_bsz1
}
graph_bsz1.launch(args, stream);
#endif
#undef USE_GRAPH
}
}
BC_BlockSparseMLP::BC_BlockSparseMLP
@@ -452,13 +450,13 @@ void BC_BlockSparseMLP::run_single_expert
c10::cuda::CUDAGuard device_guard(y.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
#define USE_GRAPH
#ifndef USE_GRAPH
if (graph_single[graphidx].disabled || (!graph_single[graphidx].ready && !graph_single[graphidx].ready_to_record))
{
run_single_expert_gr(y, expert_idx, nullptr);
#else
graph_single[graphidx].ready_to_record = true;
}
else
{
if (!graph_single[graphidx].ready)
{
prepare_ctx(y.get_device());
@@ -486,9 +484,7 @@ void BC_BlockSparseMLP::run_single_expert
};
graph_single[graphidx].launch(args, stream);
#endif
#undef USE_GRAPH
}
}

View File

@@ -59,13 +59,13 @@ void BC_GatedMLP::run_bsz1
c10::cuda::CUDAGuard device_guard(x.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
#define USE_GRAPH
#ifndef USE_GRAPH
if (graph_bsz1.disabled || (!graph_bsz1.ready && !graph_bsz1.ready_to_record))
{
run_bsz1_gr(x, d, nullptr);
#else
graph_bsz1.ready_to_record = true;
}
else
{
if (!graph_bsz1.ready)
{
graph_bsz1.capture_begin();
@@ -80,6 +80,5 @@ void BC_GatedMLP::run_bsz1
};
graph_bsz1.launch(args, stream);
#endif
}
}

View File

@@ -0,0 +1,583 @@
#include "coop_autotune.cuh"
#include <cublas_v2.h>
#include <torch/extension.h>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <map>
#include <mutex>
#include <numeric>
#include <set>
#include <vector>
#include "../util.cuh"
#include "../util.h"
//#define CACHEDEBUG 1
namespace
{
constexpr uint64_t COOP_AUTOTUNE_VERSION = 1;
constexpr char DISK_CACHE_MAGIC[8] = { 'E', 'X', '3', 'A', 'T', 'U', 'N', 'E' };
constexpr uint32_t DISK_CACHE_FORMAT = 1;
struct ExpandedCandidate
{
void* kernel;
int block_dim;
int num_sms;
int concurrency;
int tag;
float latency;
std::vector<float> samples;
};
struct DiskCacheHeader
{
char magic[8];
uint32_t format;
uint32_t record_size;
};
struct DiskCacheRecordV1
{
uint64_t hash;
int32_t tag;
int32_t block_dim;
int32_t num_sms;
int32_t concurrency;
uint32_t reserved0;
uint32_t reserved1;
};
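// On-disk layout: a single DiskCacheHeader followed by a flat run of fixed-size
// records, appended as new configurations are tuned. record_size is stored in
// the header so a V1 reader can walk files with larger future records, copying
// only the leading sizeof(DiskCacheRecordV1) bytes of each.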
std::map<uint64_t, CoopAutotuneLaunch> launch_cache;
std::set<std::tuple<int, void*, size_t>> attr_set;
std::mutex disk_mutex;
bool disk_cache_loaded = false;
std::map<uint64_t, DiskCacheRecordV1> disk_cache;
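// Folding the autotuner version into every key (FNV-style xor-then-multiply)
// cheaply invalidates all previously cached entries, in memory and on disk,
// whenever COOP_AUTOTUNE_VERSION is bumped.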
uint64_t salt_hash(uint64_t hash)
{
hash ^= COOP_AUTOTUNE_VERSION;
hash *= 1099511628211ull;
return hash;
}
std::filesystem::path disk_cache_path()
{
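// Resolution order: EXLLAMAV3_TUNE_CACHE (either a file path or a directory
// to place the cache in), then the platform cache directory (%LOCALAPPDATA%
// on Windows, $XDG_CACHE_HOME or ~/.cache elsewhere) under exllamav3/autotune/.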
const char* override_path = std::getenv("EXLLAMAV3_TUNE_CACHE");
if (override_path && override_path[0])
{
std::filesystem::path path = std::filesystem::path(override_path);
std::error_code ec;
if (std::filesystem::is_directory(path, ec))
return path / "coop_autotune_v1.bin";
return path;
}
std::filesystem::path base;
#ifdef _WIN32
const char* local_app_data = std::getenv("LOCALAPPDATA");
if (local_app_data && local_app_data[0])
base = std::filesystem::path(local_app_data);
else
{
const char* user_profile = std::getenv("USERPROFILE");
if (!user_profile || !user_profile[0]) return {};
std::filesystem::path user_path = std::filesystem::path(user_profile);
std::filesystem::path local_path = user_path / "AppData" / "Local";
std::error_code ec;
if (std::filesystem::exists(local_path, ec))
base = local_path;
else
base = user_path;
}
return base / "exllamav3" / "autotune" / "coop_autotune_v1.bin";
#else
const char* xdg = std::getenv("XDG_CACHE_HOME");
if (xdg && xdg[0])
base = std::filesystem::path(xdg);
else
{
const char* home = std::getenv("HOME");
if (!home || !home[0]) return {};
base = std::filesystem::path(home) / ".cache";
}
return base / "exllamav3" / "autotune" / "coop_autotune_v1.bin";
#endif
}
bool read_disk_header(std::ifstream& in)
{
DiskCacheHeader header;
in.read((char*) &header, sizeof(header));
if (!in) return false;
if (std::memcmp(header.magic, DISK_CACHE_MAGIC, sizeof(header.magic)) != 0) return false;
if (header.format != DISK_CACHE_FORMAT) return false;
if (header.record_size < sizeof(DiskCacheRecordV1)) return false;
return true;
}
void load_disk_cache_locked()
{
if (disk_cache_loaded) return;
disk_cache_loaded = true;
std::filesystem::path path = disk_cache_path();
if (path.empty() || !std::filesystem::exists(path)) return;
std::ifstream in(path, std::ios::binary);
if (!in || !read_disk_header(in)) return;
DiskCacheHeader header;
in.seekg(0, std::ios::beg);
in.read((char*) &header, sizeof(header));
if (!in) return;
std::vector<char> record_bytes(header.record_size);
while (in.read(record_bytes.data(), header.record_size))
{
DiskCacheRecordV1 record;
std::memcpy(&record, record_bytes.data(), sizeof(record));
disk_cache[record.hash] = record;
}
#ifdef CACHEDEBUG
printf
(
"coop_autotune cache loaded: path=%s records=%zu\n",
path.string().c_str(),
disk_cache.size()
);
#endif
}
void append_disk_cache_locked(const DiskCacheRecordV1& record)
{
std::filesystem::path path = disk_cache_path();
if (path.empty()) return;
std::error_code ec;
std::filesystem::create_directories(path.parent_path(), ec);
if (ec) return;
bool write_header = !std::filesystem::exists(path) || std::filesystem::file_size(path, ec) == 0;
std::ofstream out(path, std::ios::binary | std::ios::app);
if (!out) return;
if (write_header)
{
DiskCacheHeader header;
std::memcpy(header.magic, DISK_CACHE_MAGIC, sizeof(header.magic));
header.format = DISK_CACHE_FORMAT;
header.record_size = sizeof(DiskCacheRecordV1);
out.write((const char*) &header, sizeof(header));
}
out.write((const char*) &record, sizeof(record));
}
bool launch_from_disk_cache
(
uint64_t hash,
const std::vector<CoopAutotuneCandidate>& candidates,
CoopAutotuneLaunch* launch_config
)
{
std::lock_guard<std::mutex> lock(disk_mutex);
load_disk_cache_locked();
auto lookup = disk_cache.find(hash);
if (lookup == disk_cache.end()) return false;
const DiskCacheRecordV1& record = lookup->second;
#ifdef CACHEDEBUG
printf
(
"coop_autotune cache lookup: hash=%016llx tag=%d block_dim=%d num_sms=%d concurrency=%d\n",
(unsigned long long) record.hash,
record.tag,
record.block_dim,
record.num_sms,
record.concurrency
);
#endif
for (const CoopAutotuneCandidate& candidate : candidates)
{
if (candidate.tag != record.tag) continue;
if (candidate.block_dim != record.block_dim) continue;
if (record.num_sms < 1 || record.num_sms > candidate.max_num_sms) continue;
int max_concurrency = MAX(candidate.max_concurrency, 1);
int total_sms = candidate.total_sms > 0 ? candidate.total_sms : candidate.max_num_sms;
int expected_concurrency = MAX(MIN(total_sms / record.num_sms, max_concurrency), 1);
if (record.concurrency != expected_concurrency) continue;
*launch_config =
{
candidate.kernel,
record.block_dim,
record.num_sms,
record.concurrency,
record.tag
};
#ifdef CACHEDEBUG
printf
(
"coop_autotune cache hit: hash=%016llx tag=%d block_dim=%d num_sms=%d concurrency=%d\n",
(unsigned long long) record.hash,
launch_config->tag,
launch_config->block_dim,
launch_config->num_sms,
launch_config->concurrency
);
#endif
return true;
}
#ifdef CACHEDEBUG
printf
(
"coop_autotune cache rejected: hash=%016llx tag=%d block_dim=%d num_sms=%d concurrency=%d\n",
(unsigned long long) record.hash,
record.tag,
record.block_dim,
record.num_sms,
record.concurrency
);
#endif
return false;
}
void store_disk_cache(uint64_t hash, const CoopAutotuneLaunch& launch_config)
{
DiskCacheRecordV1 record =
{
hash,
launch_config.tag,
launch_config.block_dim,
launch_config.num_sms,
launch_config.concurrency,
0,
0
};
std::lock_guard<std::mutex> lock(disk_mutex);
load_disk_cache_locked();
disk_cache[hash] = record;
append_disk_cache_locked(record);
#ifdef CACHEDEBUG
printf
(
"coop_autotune cache store: hash=%016llx tag=%d block_dim=%d num_sms=%d concurrency=%d\n",
(unsigned long long) record.hash,
record.tag,
record.block_dim,
record.num_sms,
record.concurrency
);
#endif
}
void set_kernel_attr_once(void* kernel, size_t smem)
{
int device;
cuda_check(cudaGetDevice(&device));
auto key = std::make_tuple(device, kernel, smem);
if (attr_set.find(key) != attr_set.end()) return;
cuda_check(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smem));
attr_set.insert(key);
}
float trimmed_mean(std::vector<float>& samples)
{
TORCH_CHECK(!samples.empty(), "CoopKernelAutotuner: no timing samples");
std::sort(samples.begin(), samples.end());
size_t trim = samples.size() / 4;
size_t begin = trim;
size_t end = samples.size() - trim;
if (begin >= end)
{
begin = 0;
end = samples.size();
}
double sum = 0.0;
for (size_t i = begin; i < end; ++i) sum += samples[i];
return (float) (sum / (double) (end - begin));
}
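// Example: with 8 samples sorted as {0.9, 1.0, 1.0, 1.1, 1.1, 1.2, 1.3, 9.7},
// trim = 8 / 4 = 2 from each end, so the mean is taken over the middle four
// values and the 9.7 outlier (e.g. a launch delayed by clock ramp-up) is dropped.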
void measure_candidate_sample
(
ExpandedCandidate& candidate,
void** kernel_args,
size_t smem,
cudaStream_t stream,
int repeats,
cudaEvent_t start,
cudaEvent_t end
)
{
cuda_check(cudaEventRecord(start, stream));
for (int i = 0; i < repeats; ++i)
{
cuda_check(cudaLaunchCooperativeKernel
(
candidate.kernel,
dim3(candidate.num_sms, 1, candidate.concurrency),
candidate.block_dim,
kernel_args,
smem,
stream
));
}
cuda_check(cudaEventRecord(end, stream));
cuda_check(cudaEventSynchronize(end));
float ms = 0.0f;
cuda_check(cudaEventElapsedTime(&ms, start, end));
candidate.samples.push_back(ms / (float) repeats);
}
void keep_best(std::vector<ExpandedCandidate>& candidates, size_t limit)
{
for (ExpandedCandidate& candidate : candidates)
candidate.latency = trimmed_mean(candidate.samples);
std::sort
(
candidates.begin(),
candidates.end(),
[] (const ExpandedCandidate& a, const ExpandedCandidate& b)
{
return a.latency < b.latency;
}
);
if (candidates.size() > limit) candidates.resize(limit);
}
void measure_stage
(
std::vector<ExpandedCandidate>& candidates,
void** kernel_args,
size_t smem,
cudaStream_t stream,
int rounds,
int repeats,
size_t keep,
cudaEvent_t start,
cudaEvent_t end
)
{
TORCH_CHECK(!candidates.empty(), "CoopKernelAutotuner: no candidates in stage");
for (ExpandedCandidate& candidate : candidates)
{
candidate.samples.clear();
candidate.samples.reserve(rounds);
set_kernel_attr_once(candidate.kernel, smem);
// One untimed launch avoids first-use effects from contaminating the first measured round.
cuda_check(cudaLaunchCooperativeKernel
(
candidate.kernel,
dim3(candidate.num_sms, 1, candidate.concurrency),
candidate.block_dim,
kernel_args,
smem,
stream
));
}
cuda_check(cudaStreamSynchronize(stream));
std::vector<int> order(candidates.size());
std::iota(order.begin(), order.end(), 0);
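// Each round walks the candidates in a different order: the starting position
// and stride vary per round, and the stride is bumped until it is coprime with
// n so one cycle still visits every candidate exactly once. This decorrelates a
// candidate's samples from slow drift (clocks, temperature) that would
// otherwise favor whichever kernels happen to run first.
// E.g. n = 6, round 1: step starts at 3, gcd(3, 6) = 3, so step becomes 5.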
for (int round = 0; round < rounds; ++round)
{
size_t n = order.size();
size_t step = 2 * (size_t) round + 1;
while (std::gcd(step, n) != 1) step += 2;
size_t pos = ((uint64_t) round * 1103515245ull + 12345ull) % n;
for (size_t i = 0; i < n; ++i)
{
measure_candidate_sample
(
candidates[order[pos]],
kernel_args,
smem,
stream,
repeats,
start,
end
);
pos = (pos + step) % n;
}
}
keep_best(candidates, keep);
}
CoopAutotuneLaunch tune
(
const std::vector<CoopAutotuneCandidate>& base_candidates,
void** kernel_args,
size_t smem,
cudaStream_t stream,
size_t numel_B
)
{
std::vector<ExpandedCandidate> candidates;
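// Expand each kernel into a sweep of grid sizes: a single-SM grid where extra
// concurrency can occupy the idle SMs (or where one SM is all the kernel
// supports), every even SM count up to 85% of the kernel's maximum, and the
// maximum itself. The z dimension of each grid packs total_sms / num_sms
// concurrent copies, clamped to the kernel's concurrency limit.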
for (const CoopAutotuneCandidate& base : base_candidates)
{
TORCH_CHECK(base.kernel, "CoopKernelAutotuner: null kernel candidate");
TORCH_CHECK(base.block_dim > 0, "CoopKernelAutotuner: invalid block_dim");
TORCH_CHECK(base.max_num_sms > 0, "CoopKernelAutotuner: invalid max_num_sms");
int max_concurrency = MAX(base.max_concurrency, 1);
int total_sms = base.total_sms > 0 ? base.total_sms : base.max_num_sms;
if (max_concurrency > 1 || base.max_num_sms == 1)
{
int concurrency = MAX(MIN(total_sms, max_concurrency), 1);
candidates.push_back({ base.kernel, base.block_dim, 1, concurrency, base.tag, 0.0f, {} });
}
for (int num_sms = 2; num_sms <= base.max_num_sms * 85 / 100; num_sms += 2)
{
int concurrency = MAX(MIN(total_sms / num_sms, max_concurrency), 1);
candidates.push_back({ base.kernel, base.block_dim, num_sms, concurrency, base.tag, 0.0f, {} });
}
if (base.max_num_sms > 1)
{
int concurrency = MAX(MIN(total_sms / base.max_num_sms, max_concurrency), 1);
candidates.push_back({ base.kernel, base.block_dim, base.max_num_sms, concurrency, base.tag, 0.0f, {} });
}
}
TORCH_CHECK(!candidates.empty(), "CoopKernelAutotuner: no candidates");
cudaEvent_t start;
cudaEvent_t end;
cuda_check(cudaEventCreate(&start));
cuda_check(cudaEventCreate(&end));
int repeats = 20;
if (numel_B > 1e6) repeats = 10;
if (numel_B > 1e7) repeats = 5;
if (numel_B > 1e8) repeats = 3;
int max_rounds = 64;
if (numel_B > 1e7) max_rounds = 20;
if (numel_B > 1e8) max_rounds = 10;
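// Three-stage tournament: a short pass over every candidate keeps the best 16,
// a longer pass narrows those to 8, and a final pass leaves a single winner, so
// most of the measurement budget lands on plausible configurations. Rounds and
// repeats shrink with the working-set size to bound total tuning time.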
measure_stage(candidates, kernel_args, smem, stream, MIN(8, max_rounds), repeats, 16, start, end);
measure_stage(candidates, kernel_args, smem, stream, MIN(40, max_rounds), repeats, 8, start, end);
measure_stage(candidates, kernel_args, smem, stream, MIN(64, max_rounds), repeats, 1, start, end);
cuda_check(cudaEventDestroy(start));
cuda_check(cudaEventDestroy(end));
const ExpandedCandidate& best = candidates[0];
return { best.kernel, best.block_dim, best.num_sms, best.concurrency, best.tag };
}
} // namespace
bool CoopKernelAutotuner::launch_locked
(
uint64_t hash,
void** kernel_args,
size_t smem,
cudaStream_t stream,
CoopAutotuneLaunch* out_launch_config
)
{
CoopAutotuneLaunch launch_config;
hash = salt_hash(hash);
{
auto lookup = launch_cache.find(hash);
if (lookup == launch_cache.end()) return false;
launch_config = lookup->second;
}
set_kernel_attr_once(launch_config.kernel, smem);
cuda_check(cudaLaunchCooperativeKernel
(
launch_config.kernel,
dim3(launch_config.num_sms, 1, launch_config.concurrency),
launch_config.block_dim,
kernel_args,
smem,
stream
));
if (out_launch_config) *out_launch_config = launch_config;
return true;
}
CoopAutotuneLaunch CoopKernelAutotuner::launch
(
uint64_t hash,
const std::vector<CoopAutotuneCandidate>& candidates,
void** kernel_args,
size_t smem,
cudaStream_t stream,
size_t numel_B
)
{
CoopAutotuneLaunch launch_config;
uint64_t salted_hash = salt_hash(hash);
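// launch_locked salts the hash itself, so it takes the raw value; the
// disk-cache lookup and store paths below use salted_hash directly.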
if (launch_locked(hash, kernel_args, smem, stream, &launch_config))
return launch_config;
if (launch_from_disk_cache(salted_hash, candidates, &launch_config))
{
launch_cache[salted_hash] = launch_config;
set_kernel_attr_once(launch_config.kernel, smem);
cuda_check(cudaLaunchCooperativeKernel
(
launch_config.kernel,
dim3(launch_config.num_sms, 1, launch_config.concurrency),
launch_config.block_dim,
kernel_args,
smem,
stream
));
return launch_config;
}
{
launch_config = tune(candidates, kernel_args, smem, stream, numel_B);
auto inserted = launch_cache.emplace(salted_hash, launch_config);
launch_config = inserted.first->second;
}
store_disk_cache(salted_hash, launch_config);
set_kernel_attr_once(launch_config.kernel, smem);
cuda_check(cudaLaunchCooperativeKernel
(
launch_config.kernel,
dim3(launch_config.num_sms, 1, launch_config.concurrency),
launch_config.block_dim,
kernel_args,
smem,
stream
));
return launch_config;
}
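For debugging, the cache file can be inspected offline. A minimal sketch of a standalone dumper (hypothetical, not part of this commit); the structs mirror the V1 layout above, and the path in the comment is the Linux default produced by disk_cache_path():
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <vector>
struct Header { char magic[8]; uint32_t format; uint32_t record_size; };
struct Record { uint64_t hash; int32_t tag, block_dim, num_sms, concurrency; uint32_t r0, r1; };
// Usage: dump_tune_cache ~/.cache/exllamav3/autotune/coop_autotune_v1.bin
int main(int argc, char** argv)
{
    if (argc < 2) { std::printf("usage: dump_tune_cache <cache file>\n"); return 1; }
    std::ifstream in(argv[1], std::ios::binary);
    Header h;
    in.read((char*) &h, sizeof(h));
    if (!in || std::memcmp(h.magic, "EX3ATUNE", 8) != 0 ||
        h.format != 1 || h.record_size < sizeof(Record)) return 1;
    std::vector<char> buf(h.record_size);
    Record r;
    // Records are appended whole; a later record for the same hash supersedes
    // an earlier one, matching the load order in load_disk_cache_locked().
    while (in.read(buf.data(), h.record_size))
    {
        std::memcpy(&r, buf.data(), sizeof(r));
        std::printf("%016llx tag=%d block_dim=%d num_sms=%d concurrency=%d\n",
            (unsigned long long) r.hash, r.tag, r.block_dim, r.num_sms, r.concurrency);
    }
    return 0;
}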

View File

@@ -0,0 +1,47 @@
#pragma once
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>
struct CoopAutotuneCandidate
{
void* kernel;
int block_dim;
int max_num_sms;
int max_concurrency;
int total_sms;
int tag;
};
struct CoopAutotuneLaunch
{
void* kernel;
int block_dim;
int num_sms;
int concurrency;
int tag;
};
class CoopKernelAutotuner
{
public:
static bool launch_locked
(
uint64_t hash,
void** kernel_args,
size_t smem,
cudaStream_t stream,
CoopAutotuneLaunch* launch_config = nullptr
);
static CoopAutotuneLaunch launch
(
uint64_t hash,
const std::vector<CoopAutotuneCandidate>& candidates,
void** kernel_args,
size_t smem,
cudaStream_t stream,
size_t numel_B = 1000000000
);
};
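Taken together, a call site drives the autotuner in two phases: a fast path that replays a previously tuned launch, and a slow path that benchmarks a candidate list once and caches the winner. A minimal sketch, assuming a simple grid-stride kernel; the kernel, key and tags are hypothetical, and the candidate initializers follow CoopAutotuneCandidate's field order:
#include <cuda_runtime.h>
#include <vector>
#include "coop_autotune.cuh"
// Hypothetical cooperative-launch-compatible kernel: scales x in grid strides.
__global__ void scale_kernel(float* x, int n, float s)
{
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x)
        x[i] *= s;
}
void scale_autotuned(float* x, int n, float s, cudaStream_t stream, int total_sms)
{
    void* args[] = { (void*) &x, (void*) &n, (void*) &s };
    uint64_t key = 1469598103934665603ull ^ (uint64_t) n;  // any stable shape hash
    // Fast path: replay a configuration tuned earlier in this process.
    if (CoopKernelAutotuner::launch_locked(key, args, 0, stream)) return;
    // Slow path: benchmark both block sizes over the num_sms sweep, launch the
    // winner, and persist it to the in-memory and on-disk caches.
    std::vector<CoopAutotuneCandidate> candidates =
    {
        // kernel, block_dim, max_num_sms, max_concurrency, total_sms, tag
        { (void*) scale_kernel, 128, total_sms, 1, total_sms, 1 },
        { (void*) scale_kernel, 256, total_sms, 1, total_sms, 2 },
    };
    CoopKernelAutotuner::launch(key, candidates, args, 0, stream, (size_t) n * sizeof(float));
}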

View File

@@ -11,13 +11,13 @@ namespace cg = cooperative_groups;
#include "exl3_kernel_map.cuh"
#include "exl3_devctx.cuh"
#include "exl3_gemv.cuh"
#include "coop_autotune.cuh"
#include <set>
#define NEW_TUNE_GEMM
#define NEW_TUNE_MGEMM
#include <vector>
int exl3_gemm_tilesize_k_g[] = {EXL3_GEMM_TILESIZE_K};
int exl3_gemm_tilesize_n_g[] = {EXL3_GEMM_TILESIZE_N};
int exl3_gemm_blockdim_g[] = {EXL3_GEMM_BLOCKDIM};
/*
EXL3 matmul, A @ B -> C
@@ -36,6 +36,63 @@ limitations:
std::set<void*> kernel_attr_set[MAX_DEVICES] = {};
uint64_t gemm_autotune_hash
(
int size_m,
int size_k,
int size_n,
int K,
bool c_fp32,
int device,
int cc,
int max_num_sms,
int cb
)
{
uint64_t h = 1469598103934665603ull;
auto mix = [&] (uint64_t v)
{
h ^= v;
h *= 1099511628211ull;
};
mix((uint64_t) size_m);
mix((uint64_t) size_k);
mix((uint64_t) size_n);
mix((uint64_t) K);
mix(c_fp32 ? 1ull : 0ull);
mix((uint64_t) device);
mix((uint64_t) cc);
mix((uint64_t) max_num_sms);
mix((uint64_t) cb);
return h;
}
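// Standard 64-bit FNV-1a: start from the offset basis and fold in each field
// with xor followed by multiplication by the FNV prime, so any single differing
// field (shape, dtype, device, SM count, ...) yields an unrelated cache key.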
uint64_t mgemm_autotune_hash
(
int size_m,
int size_k,
int size_n,
int K,
bool c_fp32,
int device,
int cc,
int max_num_sms,
int cb,
int bszm_in,
int bszm_out
)
{
uint64_t h = gemm_autotune_hash(size_m, size_k, size_n, K, c_fp32, device, cc, max_num_sms, cb);
auto mix = [&] (uint64_t v)
{
h ^= v;
h *= 1099511628211ull;
};
mix((uint64_t) bszm_in);
mix((uint64_t) bszm_out);
return h;
}
int exl3_gemm_gr
(
const at::Tensor& A,
@@ -108,30 +165,6 @@ int exl3_gemm_gr
int shape_idx;
fp_exl3_gemm_kernel kernel;
#ifndef NEW_TUNE_GEMM
kernel = select_exl3_gemm_kernel
(
cc, size_m, size_k, size_n, K, c_fp32,
force_shape_idx, &block_dim, &shape_idx,
&num_sms, cb
);
if (!kernel) return 0;
#else
TResult* tr = select_exl3_gemm_mgemm_kernel_new(cc, size_m, size_k, size_n, K, c_fp32, force_shape_idx, force_num_sms, cb);
if (!tr) return 0;
num_sms = MIN(num_sms, tr->num_sms);
kernel = tr->kernel;
block_dim = tr->block_dim;
shape_idx = tr->shape_idx;
#endif
// Launch
if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end())
{
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
kernel_attr_set[device].insert((void*) kernel);
cuda_check(cudaPeekAtLastError());
}
void* kernelArgs[] =
{
(void*)& A_ptr,
@@ -145,6 +178,79 @@ int exl3_gemm_gr
(void*)& A_had_ptr,
(void*)& svh_ptr
};
auto add_graph_args = [&](void* kernel_ptr)
{
if (graph)
{
graph->record_param(kernel_ptr, GP_gemm_A, 0);
graph->record_param(kernel_ptr, GP_gemm_B_trellis, 1);
graph->record_param(kernel_ptr, GP_gemm_C, 2);
graph->record_param(kernel_ptr, GP_gemm_B_suh, 7);
graph->record_param(kernel_ptr, GP_gemm_A_had, 8);
graph->record_param(kernel_ptr, GP_gemm_B_svh, 9);
graph->record_param(kernel_ptr, GP_end, 0);
}
};
bool autotune = force_shape_idx <= 0 && force_num_sms <= 0;
if (autotune)
{
uint64_t autotune_key = gemm_autotune_hash(MAX(size_m, 2), size_k, size_n, K, c_fp32, device, cc, num_sms, cb);
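// size_m is clamped to 2 in the key, so m = 1 and m = 2 share one cache
// entry; presumably the same configuration wins for both tiny batch sizes.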
CoopAutotuneLaunch tuned;
if (CoopKernelAutotuner::launch_locked(autotune_key, kernelArgs, SMEM_MAX, stream, &tuned))
{
add_graph_args((void*) tuned.kernel);
cuda_check(cudaPeekAtLastError());
return tuned.tag;
}
std::vector<CoopAutotuneCandidate> candidates;
for (int candidate_shape_idx = 1; candidate_shape_idx <= EXL3_GEMM_NUM_SHAPES; ++candidate_shape_idx)
{
if (!exl3_gemm_shape_compat(candidate_shape_idx, size_m, size_k, size_n, K)) continue;
fp_exl3_gemm_kernel candidate_kernel = get_gemm_kernel_ptr(K, candidate_shape_idx, c_fp32, cb);
if (!candidate_kernel) continue;
int tilesize_k = exl3_gemm_tilesize_k_g[candidate_shape_idx];
int tilesize_n = exl3_gemm_tilesize_n_g[candidate_shape_idx];
int max_slices = MAX(size_k / tilesize_k * size_n / tilesize_n, 1);
int max_candidate_sms = MAX(MIN(max_slices, num_sms), 1);
candidates.push_back
({
(void*) candidate_kernel,
exl3_gemm_blockdim_g[candidate_shape_idx],
max_candidate_sms,
1,
max_candidate_sms,
candidate_shape_idx
});
}
TORCH_CHECK(!candidates.empty(), "exl3_gemm autotune: no compatible kernel shapes");
tuned = CoopKernelAutotuner::launch(autotune_key, candidates, kernelArgs, SMEM_MAX, stream, (size_t) size_k * size_n);
if (graph)
add_graph_args((void*) tuned.kernel);
cuda_check(cudaPeekAtLastError());
return tuned.tag;
}
kernel = select_exl3_gemm_kernel
(
cc, size_m, size_k, size_n, K, c_fp32,
force_shape_idx, &block_dim, &shape_idx,
&num_sms, cb
);
if (!kernel) return 0;
// Launch
if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end())
{
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
kernel_attr_set[device].insert((void*) kernel);
cuda_check(cudaPeekAtLastError());
}
cudaLaunchCooperativeKernel
(
(void*) kernel,
@@ -154,14 +260,7 @@ int exl3_gemm_gr
SMEM_MAX,
stream
);
if (graph) graph->record_param((void*) kernel, GP_gemm_A, 0);
if (graph) graph->record_param((void*) kernel, GP_gemm_B_trellis, 1);
if (graph) graph->record_param((void*) kernel, GP_gemm_C, 2);
if (graph) graph->record_param((void*) kernel, GP_gemm_B_suh, 7);
if (graph) graph->record_param((void*) kernel, GP_gemm_A_had, 8);
if (graph) graph->record_param((void*) kernel, GP_gemm_B_svh, 9);
if (graph) graph->record_param((void*) kernel, GP_end, 0);
add_graph_args((void*) kernel);
cuda_check(cudaPeekAtLastError());
return shape_idx;
@@ -304,40 +403,6 @@ int exl3_mgemm_gr
fp_exl3_mgemm_kernel kernel;
int concurrency;
#ifndef NEW_TUNE_MGEMM
kernel = select_exl3_mgemm_kernel
(
cc, size_m, size_k, size_n, K, c_fp32,
force_shape_idx, &block_dim, &shape_idx,
&num_sms, cb, bszm_in, bszm_out
);
if (!kernel) return 0;
concurrency = MIN(total_sms / num_sms, bszm_out);
#else
kernel = select_exl3_mgemm_kernel
(
cc, size_m, size_k, size_n, K, c_fp32,
force_shape_idx, &block_dim, &shape_idx,
&num_sms, cb, bszm_in, bszm_out
);
int tilesize_k = exl3_gemm_tilesize_k_g[shape_idx];
int tilesize_n = exl3_gemm_tilesize_n_g[shape_idx];
int tiles = MAX(size_k / tilesize_k * size_n / tilesize_n, 1);
num_sms = tiles;
if (num_sms * bszm > total_sms) num_sms = MAX(total_sms / bszm, 1);
if (num_sms <= total_sms && tiles / num_sms > 48) num_sms = MIN(total_sms, num_sms * 2);
concurrency = MIN(total_sms / num_sms, bszm);
#endif
// Launch bigger grid if possible
dim3 block_grid(num_sms, 1, concurrency);
// Launch
if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end())
{
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
kernel_attr_set[device].insert((void*) kernel);
}
void* kernelArgs[] =
{
(void*)& A_ptr,
@@ -358,6 +423,96 @@ int exl3_mgemm_gr
(void*)& max_index
};
auto add_graph_args = [&](void* kernel_ptr)
{
if (graph)
{
graph->record_param(kernel_ptr, GP_mgemm_A, 0);
graph->record_param(kernel_ptr, GP_mgemm_C, 2);
graph->record_param(kernel_ptr, GP_mgemm_indices, 10);
graph->record_param(kernel_ptr, GP_mgemm_weights, 11);
graph->record_param(kernel_ptr, GP_end, 0);
}
};
bool autotune = force_shape_idx <= 0 && force_num_sms <= 0;
if (autotune)
{
uint64_t autotune_key = mgemm_autotune_hash
(
size_m, size_k, size_n, K, c_fp32, device, cc, total_sms, cb, bszm_in, bszm_out
);
CoopAutotuneLaunch tuned;
if (CoopKernelAutotuner::launch_locked(autotune_key, kernelArgs, SMEM_MAX, stream, &tuned))
{
add_graph_args((void*) tuned.kernel);
cuda_check(cudaPeekAtLastError());
return tuned.tag;
}
if (!graph)
{
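// Never tune while a CUDA graph is being captured: the autotuner synchronizes
// on events and streams, which is illegal during stream capture. Untuned
// shapes under capture fall through to the heuristic path below.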
std::vector<CoopAutotuneCandidate> candidates;
for (int candidate_shape_idx = 1; candidate_shape_idx <= EXL3_GEMM_NUM_SHAPES; ++candidate_shape_idx)
{
if (!exl3_gemm_shape_compat(candidate_shape_idx, size_m, size_k, size_n, K)) continue;
fp_exl3_mgemm_kernel candidate_kernel = get_mgemm_kernel_ptr(K, candidate_shape_idx, c_fp32, cb);
if (!candidate_kernel) continue;
int tilesize_k = exl3_gemm_tilesize_k_g[candidate_shape_idx];
int tilesize_n = exl3_gemm_tilesize_n_g[candidate_shape_idx];
int max_slices = MAX(size_k / tilesize_k * size_n / tilesize_n, 1);
int max_candidate_sms = MAX(MIN(max_slices, total_sms), 1);
candidates.push_back
({
(void*) candidate_kernel,
exl3_gemm_blockdim_g[candidate_shape_idx],
max_candidate_sms,
bszm,
total_sms,
candidate_shape_idx
});
}
TORCH_CHECK(!candidates.empty(), "exl3_mgemm autotune: no compatible kernel shapes");
tuned = CoopKernelAutotuner::launch(autotune_key, candidates, kernelArgs, SMEM_MAX, stream, (size_t) size_k * size_n * bszm);
add_graph_args((void*) tuned.kernel);
// DBGI10(size_m, size_k, size_n, K, bszm_in, bszm_out, tuned.tag, tuned.block_dim, tuned.num_sms, tuned.concurrency);
cuda_check(cudaPeekAtLastError());
return tuned.tag;
}
}
kernel = select_exl3_mgemm_kernel
(
cc, size_m, size_k, size_n, K, c_fp32,
force_shape_idx, &block_dim, &shape_idx,
&num_sms, cb, bszm_in, bszm_out
);
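// Heuristic fallback (used when tuning is forced off or under graph capture):
// one SM per tile, scaled down so bszm concurrent grids fit on the device, and
// doubled when each SM would otherwise serialize more than 48 tiles.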
int tilesize_k = exl3_gemm_tilesize_k_g[shape_idx];
int tilesize_n = exl3_gemm_tilesize_n_g[shape_idx];
int tiles = MAX(size_k / tilesize_k * size_n / tilesize_n, 1);
num_sms = tiles;
if (num_sms * bszm > total_sms) num_sms = MAX(total_sms / bszm, 1);
if (num_sms <= total_sms && tiles / num_sms > 48) num_sms = MIN(total_sms, num_sms * 2);
concurrency = MIN(total_sms / num_sms, bszm);
// DBGI10(size_m, size_k, size_n, K, bszm_in, bszm_out, shape_idx, block_dim, num_sms, concurrency);
// Launch bigger grid if possible
dim3 block_grid(num_sms, 1, concurrency);
// Launch
if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end())
{
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
kernel_attr_set[device].insert((void*) kernel);
}
cudaLaunchCooperativeKernel
(
(void*) kernel,
@@ -367,12 +522,7 @@ int exl3_mgemm_gr
SMEM_MAX,
stream
);
if (graph) graph->record_param((void*) kernel, GP_mgemm_A, 0);
if (graph) graph->record_param((void*) kernel, GP_mgemm_C, 2);
if (graph) graph->record_param((void*) kernel, GP_mgemm_indices, 10);
if (graph) graph->record_param((void*) kernel, GP_mgemm_weights, 11);
if (graph) graph->record_param((void*) kernel, GP_end, 0);
add_graph_args((void*) kernel);
cuda_check(cudaPeekAtLastError());
return shape_idx;

View File

@@ -20,9 +20,6 @@ namespace cg = cooperative_groups;
#include "comp_units/exl3_comp_unit_7.cuh"
#include "comp_units/exl3_comp_unit_8.cuh"
#include "exl3_kernel_map_samples.cuh"
std::map<uint64_t, TResult> _tuning_cache = {};
int select_gemm_shape(int cc, int size_m, int size_k, int size_n, int K, bool multi, int bszm_in, int bszm_out)
{
bool mod_256 = (size_n % 256 == 0);
@@ -295,85 +292,3 @@ fp_exl3_mgemm_kernel get_mgemm_kernel_ptr(int K, int shape_idx, bool c_fp32, int
}
}
}
TResult f_tr;
TResult* select_exl3_gemm_mgemm_kernel_new
(
int cc,
int size_m,
int size_k,
int size_n,
int K,
bool c_fp32,
int force_shape_idx,
int force_num_sms,
int cb
)
{
// Force parameters for tuning/benchmarking
if (force_shape_idx > 0)
{
TORCH_CHECK(force_num_sms, "Must supply force_shape_idx and force_num_sms together");
f_tr.kernel = get_gemm_kernel_ptr(K, force_shape_idx, c_fp32, cb);
f_tr.mkernel = get_mgemm_kernel_ptr(K, force_shape_idx, c_fp32, cb);
f_tr.shape_idx = force_shape_idx;
f_tr.num_sms = force_num_sms;
f_tr.block_dim = exl3_gemm_blockdim[force_shape_idx];
return &f_tr;
};
TORCH_CHECK(!force_num_sms, "Must supply force_shape_idx and force_num_sms together.");
// Cache parameters
uint64_t key = (((uint64_t) size_k) << 40) |
(((uint64_t) size_n) << 16) |
(((uint64_t) cc) << 8) |
(((uint64_t) K) << 4) |
(c_fp32 ? 0x01ull : 0x00ull);
auto lookup = _tuning_cache.find(key);
if (lookup == _tuning_cache.end())
{
// Find closest kernel in map
bool mod512 = (size_n % 512 == 0);
bool mod256 = (size_n % 256 == 0);
bool mod128 = (size_n % 128 == 0);
TORCH_CHECK(mod128, "size_n must be a multiple of 128");
TSample* cand = mod512 ? samples_512 : (mod256 ? samples_256 : samples_128);
TSample* best = nullptr;
int64_t best_dist = 1ll<<62;
for (; cand->K; cand++)
{
if (cand->K != K) continue;
if (cand->cc != cc) continue;
int64_t distk = (int64_t) (size_k - cand->k);
int64_t distn = (int64_t) (size_n - cand->n);
int64_t dist = distk * distk + distn * distn;
if (dist < best_dist) { best_dist = dist; best = cand; }
}
TORCH_CHECK(best, "Failed to find valid kernel for shape");
// Avoid empty blocks
int tilesize_k = exl3_gemm_tilesize_k[best->shape_idx];
int tilesize_n = exl3_gemm_tilesize_n[best->shape_idx];
int max_slices = size_k / tilesize_k * size_n / tilesize_n;
int num_sms = MAX(MIN(max_slices, best->num_sms), 1);
// Results
TResult tr = {
get_gemm_kernel_ptr(K, best->shape_idx, c_fp32, cb),
get_mgemm_kernel_ptr(K, best->shape_idx, c_fp32, cb),
best->shape_idx,
num_sms,
exl3_gemm_blockdim[best->shape_idx]
};
_tuning_cache[key] = tr;
}
lookup = _tuning_cache.find(key);
return &(lookup->second);
}

View File

@@ -138,46 +138,5 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
const int bszm_out
);
struct TSample {
int cc;
int K;
int m;
int k;
int n;
int shape_idx;
int num_sms;
};
struct TMSample {
int cc;
int K;
int m;
int k;
int n;
int shape_idx;
int num_sms;
int bszm_in;
int bszm_out;
};
struct TResult
{
fp_exl3_gemm_kernel kernel;
fp_exl3_mgemm_kernel mkernel;
int shape_idx;
int num_sms;
int block_dim;
};
TResult* select_exl3_gemm_mgemm_kernel_new
(
const int cc,
const int size_m,
const int size_k,
const int size_n,
const int K,
const bool c_fp32,
const int force_shape_idx,
const int force_num_sms,
const int cb
);
fp_exl3_gemm_kernel get_gemm_kernel_ptr(int K, int shape_idx, bool c_fp32, int cb);
fp_exl3_mgemm_kernel get_mgemm_kernel_ptr(int K, int shape_idx, bool c_fp32, int cb);

File diff suppressed because it is too large

View File

@@ -1,561 +0,0 @@
import sys, os
from collections import OrderedDict
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from exllamav3.ext import exllamav3_ext as ext
from exllamav3.util import Timer
from exllamav3.util.memory import free_mem
from tabulate import tabulate
import numpy as np
num_warmup_passes = 10
num_benchmark_iter_a = 20
num_benchmark_iter_b = 40
outlier_trim = 0.3
assume_cache = 384 * 1024 ** 2
devices = [
"cuda:1",
"cuda:2",
"cuda:3",
]
shapes_m = [1]
shapes_k = [
128,
256,
512,
1024,
2048,
3072,
4096,
5120,
8192,
12288,
14336,
16384,
24576,
]
shapes_n = [
128,
256,
512,
1024,
2048,
3072,
4096,
5120,
8192,
12288,
14336,
16384,
24576,
51200,
128000,
]
shape_indices_128 = [1, 2]
shape_indices_256 = [3]
shape_indices_512 = [4]
Ks = [1, 2, 3, 4, 5, 6, 7, 8]
mgemm_bszm_io = [
(1, 8),
(8, 1),
]
g_spin = 0
def get_abc(K, m, k, n, device):
proto_a = torch.randn((m, k), dtype = torch.half, device = device)
proto_b = torch.zeros((k // 16, n // 16, 16 * K), dtype = torch.short, device = device)
proto_c = torch.zeros((m, n), dtype = torch.half, device = device)
proto_suh = torch.randn((k,), dtype = torch.half, device = device)
proto_svh = torch.randn((n,), dtype = torch.half, device = device)
# Create enough clones to cycle through to prevent L2 cache from skewing results
proto_size = proto_a.numel() * 2 + proto_b.numel() * 2 + proto_c.numel() * 2
num_buffers = max(assume_cache // proto_size + 1, 2)
a = [proto_a.clone() for _ in range(num_buffers)]
b = [proto_b.clone() for _ in range(num_buffers)]
c = [proto_c.clone() for _ in range(num_buffers)]
suh = [proto_suh.clone() for _ in range(num_buffers)]
svh = [proto_svh.clone() for _ in range(num_buffers)]
return a, b, c, suh, svh
def get_abc_m(K, m, k, n, device, bszm_in, bszm_out):
bszm = max(bszm_in, bszm_out)
proto_a = torch.randn((bszm_in, m, k), dtype = torch.half, device = device)
proto_b = [torch.zeros((k // 16, n // 16, 16 * K), dtype = torch.short, device = device) for _ in range(bszm)]
proto_c = torch.zeros((bszm_out, m, n), dtype = torch.half, device = device)
proto_suh = [torch.randn((k,), dtype = torch.half, device = device) for _ in range(bszm)]
proto_svh = [torch.randn((n,), dtype = torch.half, device = device) for _ in range(bszm)]
# Create enough clones to cycle through to prevent L2 cache from skewing results
proto_size = proto_a.numel() * 2 + sum(p.numel() for p in proto_b) * 2 + proto_c.numel() * 2
num_buffers = max(assume_cache // proto_size + 1, 2)
a = [proto_a.clone() for _ in range(num_buffers)]
b = [[proto_b_.clone() for proto_b_ in proto_b] for _ in range(num_buffers)]
c = [proto_c.clone() for _ in range(num_buffers)]
suh = [[proto_suh_.clone() for proto_suh_ in proto_suh] for _ in range(num_buffers)]
svh = [[proto_svh_.clone() for proto_svh_ in proto_svh] for _ in range(num_buffers)]
trellis = b
ptrs_suh = [torch.tensor([suh__.data_ptr() for suh__ in suh_], dtype = torch.long, device = device) for suh_ in suh]
ptrs_svh = [torch.tensor([svh__.data_ptr() for svh__ in svh_], dtype = torch.long, device = device) for svh_ in svh]
ptrs_trellis = [torch.tensor([trellis__.data_ptr() for trellis__ in trellis_], dtype = torch.long, device = device) for trellis_ in trellis]
return a, b, c, suh, svh, ptrs_suh, ptrs_svh, ptrs_trellis
def warmup(a, b, c, suh, svh, shape_idx):
global g_spin
num_buffers = len(a)
for _ in range(num_warmup_passes):
i = g_spin % num_buffers
g_spin += 1
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], shape_idx, 0, 0, 64)
def warmup_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis):
num_buffers = len(a)
num_exp = ptrs_trellis[0].shape[0]
m_indices = torch.arange(0, num_exp, dtype = torch.long, device = a[0].device).unsqueeze(0)
K = b[0][0].shape[-1] // 16
for i_ in range(num_warmup_passes):
i = i_ % num_buffers
ext.exl3_mgemm(
a[i],
ptrs_trellis[i],
c[i],
ptrs_suh[i],
a[i],
ptrs_svh[i],
m_indices,
None,
K,
-1,
0,
0,
-1,
-1,
0
)
def benchmark(a, b, c, suh, svh, shape_idx, num_iter, num_sms):
num_buffers = len(a)
dummy = c[0][0, 0].item()
with Timer() as t:
for i_ in range(num_iter):
i = i_ % num_buffers
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], shape_idx, 0, 0, num_sms)
dummy = c[i][0, 0].item()
mean_time_ms = t.interval / num_iter * 1000
return mean_time_ms
def benchmark_per_launch(a, b, c, suh, svh, shape_idx, num_iter, num_sms, trim = outlier_trim, stream = None):
global g_spin
device = a[0].device
if stream is None:
stream = torch.cuda.current_stream(device)
torch.cuda.synchronize(device)
# Precreate events to reduce overhead jitter
starts = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
stops = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
# Timed loop
for it in range(num_iter):
i = g_spin % len(a)
g_spin += 1
starts[it].record(stream)
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], shape_idx, 0, 0, num_sms)
stops[it].record(stream)
# Ensure all recorded events are complete before reading times
torch.cuda.synchronize(device)
# Collect per-iteration latency in milliseconds (GPU time)
per_ms = np.array([starts[it].elapsed_time(stops[it]) for it in range(num_iter)], dtype = np.float64)
# Robust stats
median = float(np.median(per_ms))
mean = float(per_ms.mean())
std = float(per_ms.std(ddof=1)) if num_iter > 1 else 0.0
# Simple symmetric trimming
if 0.0 < trim < 0.5 and num_iter > 4:
lo = np.quantile(per_ms, trim)
hi = np.quantile(per_ms, 1.0 - trim)
trimmed = per_ms[(per_ms >= lo) & (per_ms <= hi)]
trimmed_mean = float(trimmed.mean()) if trimmed.size else float('nan')
else:
trimmed = per_ms
trimmed_mean = mean
return {
"per_launch_ms": per_ms, # numpy array of length num_iter
"mean_ms": mean,
"median_ms": median,
"std_ms": std,
"trimmed_mean_ms": trimmed_mean,
"trim_bounds_ms": (
float(lo) if 'lo' in locals() else None,
float(hi) if 'hi' in locals() else None
),
"kept_count": int(trimmed.size),
"total_count": int(num_iter),
}
def benchmark_per_launch_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis, num_iter, num_sms, trim = outlier_trim, stream = None):
device = a[0].device
if stream is None:
stream = torch.cuda.current_stream(device)
torch.cuda.synchronize(device)
num_exp = ptrs_trellis[0].shape[0]
m_indices = torch.arange(0, num_exp, dtype = torch.long, device = a[0].device).unsqueeze(0)
K = b[0][0].shape[-1] // 16
# Precreate events to reduce overhead jitter
starts = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
stops = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
# Timed loop
for it in range(num_iter):
i = it % len(a)
starts[it].record(stream)
for _ in range(2):
ext.exl3_mgemm(
a[i],
ptrs_trellis[i],
c[i],
ptrs_suh[i],
a[i],
ptrs_svh[i],
m_indices,
None,
K,
shape_idx,
0,
0,
-1,
-1,
num_sms
)
stops[it].record(stream)
# Ensure all recorded events are complete before reading times
torch.cuda.synchronize(device)
# Collect per-iteration latency in milliseconds (GPU time)
per_ms = np.array([starts[it].elapsed_time(stops[it]) / 2 for it in range(num_iter)], dtype = np.float64)
# Robust stats
median = float(np.median(per_ms))
mean = float(per_ms.mean())
std = float(per_ms.std(ddof=1)) if num_iter > 1 else 0.0
# Simple symmetric trimming
if 0.0 < trim < 0.5 and num_iter > 4:
lo = np.quantile(per_ms, trim)
hi = np.quantile(per_ms, 1.0 - trim)
trimmed = per_ms[(per_ms >= lo) & (per_ms <= hi)]
trimmed_mean = float(trimmed.mean()) if trimmed.size else float('nan')
else:
trimmed = per_ms
trimmed_mean = mean
return {
"per_launch_ms": per_ms, # numpy array of length num_iter
"mean_ms": mean,
"median_ms": median,
"std_ms": std,
"trimmed_mean_ms": trimmed_mean,
"trim_bounds_ms": (
float(lo) if 'lo' in locals() else None,
float(hi) if 'hi' in locals() else None
),
"kept_count": int(trimmed.size),
"total_count": int(num_iter),
}
def get_floor(hist_sms):
best = float("inf"), 0, 0
for l, s, i in hist_sms:
if l < best[0]:
best = l, s, i
for l, s, i in hist_sms:
if l < best[0] * 1.008:
return l, s, i
return best
def test_latencies(mod, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices):
g_hist = []
for shape_idx in shape_indices:
max_sms = ext.g_get_num_sms(device_idx)
warmup(a, b, c, suh, svh, shape_idx)
hist_sms = []
for num_sms in range(0, max_sms + 1, 8):
# mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_a, max(num_sms, 1))
num_sms_ = max(num_sms, 1)
stats = benchmark_per_launch(a, b, c, suh, svh, shape_idx, num_benchmark_iter_a, num_sms_)
mean = stats["trimmed_mean_ms"]
hist_sms.append((mean, num_sms_, shape_idx))
_, best_sms, _ = get_floor(hist_sms)
sms_a = best_sms - 8
sms_b = best_sms + 8
if sms_a < 0:
sms_a = 0
sms_b = 16
if sms_b > num_sms + 4:
sms_b = num_sms + 4
sms_a = sms_b - 16
warmup(a, b, c, suh, svh, shape_idx)
for num_sms in range(sms_a, sms_b + 1, 2):
# mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_b, min(max(num_sms, 1), max_sms))
num_sms_ = min(max(num_sms, 1), max_sms)
stats = benchmark_per_launch(a, b, c, suh, svh, shape_idx, num_benchmark_iter_b, num_sms_)
mean = stats["trimmed_mean_ms"]
g_hist.append((mean, num_sms_, shape_idx))
g_hist = sorted(g_hist, key = lambda t: t[1])
best_g_lat, best_g_sms, best_g_idx = get_floor(g_hist)
print(f"mod {mod}, dev {device_idx}, m {m:6}, k {k:6}, n {n:6}, K {K}, mean {best_g_lat:8.5f} ms, num_sms {best_g_sms:3}, shape_idx {best_g_idx}")
return best_g_lat, best_g_sms, best_g_idx
def test_latencies_m(mod, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices, bszm_in, bszm_out):
best_g_lat = float("inf")
best_g_sms = None
best_g_idx = None
best_l_lat = float("inf")
best_l_sms = None
best_l_idx = None
for shape_idx in shape_indices:
# Skip incompatible shapes
# if not ext.exl3_gemm_shape_compat(shape_idx, m, k, n, K):
# continue
max_sms = ext.g_get_num_sms(device_idx)
warmup_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis)
best_sms_lat = float("inf")
for num_sms in range(0, max_sms + 1, 8):
# mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_a, max(num_sms, 1))
stats = benchmark_per_launch_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis, num_benchmark_iter_a, max(num_sms, 1))
mean = stats["trimmed_mean_ms"]
if mean < best_sms_lat:
best_sms_lat = mean
best_sms = max(num_sms, 1)
sms_a = best_sms - 8
sms_b = best_sms + 8
if sms_a < 0:
sms_a = 0
sms_b = 16
if sms_b > num_sms + 4:
sms_b = num_sms + 4
sms_a = sms_b - 16
warmup_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis)
for num_sms in range(sms_a, sms_b + 1, 2):
# mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_b, min(max(num_sms, 1), max_sms))
stats = benchmark_per_launch_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis, num_benchmark_iter_b, min(max(num_sms, 1), max_sms))
mean = stats["trimmed_mean_ms"]
if mean < best_g_lat:
best_g_lat = mean
best_g_sms = min(max(num_sms, 1), max_sms)
best_g_idx = shape_idx
if mean < best_l_lat:
best_l_lat = mean
best_l_sms = min(max(num_sms, 1), max_sms)
best_l_idx = shape_idx
print(f"mod {mod}, dev {device_idx}, m {m:6}, k {k:6}, n {n:6}, K {K}, mean {best_l_lat:8.5f} ms, num_sms {best_l_sms:3}, shape_idx {best_l_idx}, i/o {bszm_in}/{bszm_out}")
print()
print(f" --> mod {mod}, dev {device_idx}, m {m:6}, k {k:6}, n {n:6}, K {K}, mean {best_g_lat:8.5f} ms, num_sms {best_g_sms:3}, shape_idx {best_g_idx}, i/o {bszm_in}/{bszm_out}")
print()
return best_g_lat, best_g_sms, best_g_idx
def tune_shape(K, m, k, n, device):
a, b, c, suh, svh = get_abc(K, m, k, n, device)
device_idx = torch.device(device).index
cc = ext.g_get_cc(device_idx)
res = {
"K": K,
"m": m,
"k": k,
"n": n,
"cc": cc,
"bszm_in": 1,
"bszm_out": 1,
}
res_128 = None
res_256 = None
res_512 = None
if True:
lat_128, sms_128, idx_128 = test_latencies(128, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices_128)
res_128 = res.copy()
res_128.update({"lat": lat_128, "sms": sms_128, "idx": idx_128})
if n % 256 == 0:
lat_256, sms_256, idx_256 = test_latencies(256, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices_256)
if lat_128 < lat_256:
lat_256, sms_256, idx_256 = lat_128, sms_128, idx_128
res_256 = res.copy()
res_256.update({"lat": lat_256, "sms": sms_256, "idx": idx_256})
if n % 512 == 0:
lat_512, sms_512, idx_512 = test_latencies(512, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices_512)
if lat_128 < lat_512:
lat_512, sms_512, idx_512 = lat_128, sms_128, idx_128
if lat_256 < lat_512:
lat_512, sms_512, idx_512 = lat_256, sms_256, idx_256
res_512 = res.copy()
res_512.update({"lat": lat_512, "sms": sms_512, "idx": idx_512})
return res_128, res_256, res_512
def tune_shape_m(K, m, k, n, device, bszm_in, bszm_out):
a, b, c, suh, svh, ptrs_suh, ptrs_svh, ptrs_trellis = get_abc_m(K, m, k, n, device, bszm_in, bszm_out)
device_idx = torch.device(device).index
cc = ext.g_get_cc(device_idx)
res = {
"K": K,
"m": m,
"k": k,
"n": n,
"cc": cc,
"bszm_in": 1,
"bszm_out": 1,
}
res_128 = None
res_256 = None
res_512 = None
if True:
lat_128, sms_128, idx_128 = test_latencies_m(128, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices_128, bszm_in, bszm_out)
res_128 = res.copy()
res_128.update({"lat": lat_128, "sms": sms_128, "idx": idx_128})
if n % 256 == 0:
lat_256, sms_256, idx_256 = test_latencies_m(256, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices_256, bszm_in, bszm_out)
if lat_128 < lat_256:
lat_256, sms_256, idx_256 = lat_128, sms_128, idx_128
res_256 = res.copy()
res_256.update({"lat": lat_256, "sms": sms_256, "idx": idx_256})
if n % 512 == 0:
lat_512, sms_512, idx_512 = test_latencies_m(512, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices_512, bszm_in, bszm_out)
if lat_128 < lat_512:
lat_512, sms_512, idx_512 = lat_128, sms_128, idx_128
if lat_256 < lat_512:
lat_512, sms_512, idx_512 = lat_256, sms_256, idx_256
res_512 = res.copy()
res_512.update({"lat": lat_512, "sms": sms_512, "idx": idx_512})
return res_128, res_256, res_512
def tune_gemm():
out_128 = "struct TSample samples_128[] =\n{\n"
out_256 = "struct TSample samples_256[] =\n{\n"
out_512 = "struct TSample samples_512[] =\n{\n"
for device in devices:
for m in shapes_m:
for k in shapes_k:
for n in shapes_n:
for ki, K in enumerate(Ks):
res_128, res_256, res_512 = tune_shape(K, m, k, n, device)
r = res_128
if r:
out_128 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']} }},\n"
r = res_256
if r:
out_256 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']} }},\n"
r = res_512
if r:
out_512 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']} }},\n"
out_128 = out_128 + " { 0, 0, 0, 0, 0, 0, 0 }\n};"
out_256 = out_256 + " { 0, 0, 0, 0, 0, 0, 0 }\n};"
out_512 = out_512 + " { 0, 0, 0, 0, 0, 0, 0 }\n};"
print(out_128)
print()
print(out_256)
print()
print(out_512)
print()
def tune_mgemm():
out_128 = "struct TMSample msamples_128[] =\n{\n"
out_256 = "struct TMSample msamples_256[] =\n{\n"
out_512 = "struct TMSample msamples_512[] =\n{\n"
for device in devices:
for m in shapes_m:
for k in shapes_k:
for n in shapes_n:
for ki, K in enumerate(Ks):
for (bszm_in, bszm_out) in mgemm_bszm_io:
res_128, res_256, res_512 = tune_shape_m(K, m, k, n, device, bszm_in, bszm_out)
r = res_128
if r:
out_128 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']}, {r['bszm_in']}, {r['bszm_out']} }},\n"
r = res_256
if r:
out_256 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']}, {r['bszm_in']}, {r['bszm_out']} }},\n"
r = res_512
if r:
out_512 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']}, {r['bszm_in']}, {r['bszm_out']} }},\n"
out_128 = out_128 + " { 0, 0, 0, 0, 0, 0, 0, 0, 0 }\n};"
out_256 = out_256 + " { 0, 0, 0, 0, 0, 0, 0, 0, 0 }\n};"
out_512 = out_512 + " { 0, 0, 0, 0, 0, 0, 0, 0, 0 }\n};"
print(out_128)
print()
print(out_256)
print()
print(out_512)
print()
@torch.inference_mode()
def main():
tune_gemm()
# tune_mgemm()
if __name__ == "__main__":
main()