mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-05-11 08:20:05 +00:00
GEMM/MGEMM: Add autotuning with disk cache, remove static tuning table
This commit is contained in:
@@ -7,9 +7,13 @@
|
||||
#include "util.cuh"
|
||||
#include "quant/exl3_devctx.cuh"
|
||||
|
||||
//#define GRAPHDEBUG 1
|
||||
|
||||
Graph::Graph()
|
||||
{
|
||||
ready = false;
|
||||
ready_to_record = false;
|
||||
disabled = false;
|
||||
graph = NULL;
|
||||
graph_exec = NULL;
|
||||
need_cublas = false;
|
||||
@@ -23,6 +27,10 @@ Graph::~Graph()
|
||||
|
||||
cudaStream_t Graph::capture_begin()
|
||||
{
|
||||
#ifdef GRAPHDEBUG
|
||||
printf("Begin graph capture\n");
|
||||
#endif
|
||||
|
||||
// Make sure nothing is pending
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
@@ -88,6 +96,10 @@ void Graph::capture_end()
|
||||
|
||||
// Graph is ready
|
||||
ready = true;
|
||||
|
||||
#ifdef GRAPHDEBUG
|
||||
printf("End graph capture, num_nodes=%d, graph_sites.size()=%d\n", num_nodes, graph_sites.size());
|
||||
#endif
|
||||
}
|
||||
|
||||
void Graph::record_param(void* kernel, int param_id, int param_offset)
|
||||
|
||||
@@ -73,6 +73,8 @@ public:
|
||||
|
||||
bool need_cublas;
|
||||
bool ready;
|
||||
bool ready_to_record;
|
||||
bool disabled;
|
||||
|
||||
Graph();
|
||||
~Graph();
|
||||
|
||||
@@ -164,13 +164,13 @@ void BC_BlockSparseMLP::run_bsz1
|
||||
c10::cuda::CUDAGuard device_guard(y.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
#define USE_GRAPH
|
||||
#ifndef USE_GRAPH
|
||||
|
||||
if (graph_bsz1.disabled || (!graph_bsz1.ready && !graph_bsz1.ready_to_record))
|
||||
{
|
||||
run_bsz1_gr(y, selected_experts, routing_weights, nullptr);
|
||||
|
||||
#else
|
||||
|
||||
graph_bsz1.ready_to_record = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!graph_bsz1.ready)
|
||||
{
|
||||
graph_bsz1.capture_begin();
|
||||
@@ -214,9 +214,7 @@ void BC_BlockSparseMLP::run_bsz1
|
||||
}
|
||||
|
||||
graph_bsz1.launch(args, stream);
|
||||
|
||||
#endif
|
||||
#undef USE_GRAPH
|
||||
}
|
||||
}
|
||||
|
||||
BC_BlockSparseMLP::BC_BlockSparseMLP
|
||||
@@ -452,13 +450,13 @@ void BC_BlockSparseMLP::run_single_expert
|
||||
c10::cuda::CUDAGuard device_guard(y.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
#define USE_GRAPH
|
||||
#ifndef USE_GRAPH
|
||||
|
||||
if (graph_single[graphidx].disabled || (!graph_single[graphidx].ready && !graph_single[graphidx].ready_to_record))
|
||||
{
|
||||
run_single_expert_gr(y, expert_idx, nullptr);
|
||||
|
||||
#else
|
||||
|
||||
graph_single[graphidx].ready_to_record = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!graph_single[graphidx].ready)
|
||||
{
|
||||
prepare_ctx(y.get_device());
|
||||
@@ -486,9 +484,7 @@ void BC_BlockSparseMLP::run_single_expert
|
||||
};
|
||||
|
||||
graph_single[graphidx].launch(args, stream);
|
||||
|
||||
#endif
|
||||
#undef USE_GRAPH
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -59,13 +59,13 @@ void BC_GatedMLP::run_bsz1
|
||||
c10::cuda::CUDAGuard device_guard(x.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
#define USE_GRAPH
|
||||
#ifndef USE_GRAPH
|
||||
|
||||
if (graph_bsz1.disabled || (!graph_bsz1.ready && !graph_bsz1.ready_to_record))
|
||||
{
|
||||
run_bsz1_gr(x, d, nullptr);
|
||||
|
||||
#else
|
||||
|
||||
graph_bsz1.ready_to_record = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!graph_bsz1.ready)
|
||||
{
|
||||
graph_bsz1.capture_begin();
|
||||
@@ -80,6 +80,5 @@ void BC_GatedMLP::run_bsz1
|
||||
};
|
||||
|
||||
graph_bsz1.launch(args, stream);
|
||||
|
||||
#endif
|
||||
}
|
||||
}
|
||||
583
exllamav3/exllamav3_ext/quant/coop_autotune.cu
Normal file
583
exllamav3/exllamav3_ext/quant/coop_autotune.cu
Normal file
@@ -0,0 +1,583 @@
|
||||
#include "coop_autotune.cuh"
|
||||
|
||||
#include <cublas_v2.h>
|
||||
#include <torch/extension.h>
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include "../util.cuh"
|
||||
#include "../util.h"
|
||||
|
||||
//#define CACHEDEBUG 1
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
constexpr uint64_t COOP_AUTOTUNE_VERSION = 1;
|
||||
constexpr char DISK_CACHE_MAGIC[8] = { 'E', 'X', '3', 'A', 'T', 'U', 'N', 'E' };
|
||||
constexpr uint32_t DISK_CACHE_FORMAT = 1;
|
||||
|
||||
struct ExpandedCandidate
|
||||
{
|
||||
void* kernel;
|
||||
int block_dim;
|
||||
int num_sms;
|
||||
int concurrency;
|
||||
int tag;
|
||||
float latency;
|
||||
std::vector<float> samples;
|
||||
};
|
||||
|
||||
struct DiskCacheHeader
|
||||
{
|
||||
char magic[8];
|
||||
uint32_t format;
|
||||
uint32_t record_size;
|
||||
};
|
||||
|
||||
struct DiskCacheRecordV1
|
||||
{
|
||||
uint64_t hash;
|
||||
int32_t tag;
|
||||
int32_t block_dim;
|
||||
int32_t num_sms;
|
||||
int32_t concurrency;
|
||||
uint32_t reserved0;
|
||||
uint32_t reserved1;
|
||||
};
|
||||
|
||||
std::map<uint64_t, CoopAutotuneLaunch> launch_cache;
|
||||
|
||||
std::set<std::tuple<int, void*, size_t>> attr_set;
|
||||
|
||||
std::mutex disk_mutex;
|
||||
bool disk_cache_loaded = false;
|
||||
std::map<uint64_t, DiskCacheRecordV1> disk_cache;
|
||||
|
||||
uint64_t salt_hash(uint64_t hash)
|
||||
{
|
||||
hash ^= COOP_AUTOTUNE_VERSION;
|
||||
hash *= 1099511628211ull;
|
||||
return hash;
|
||||
}
|
||||
|
||||
std::filesystem::path disk_cache_path()
|
||||
{
|
||||
const char* override_path = std::getenv("EXLLAMAV3_TUNE_CACHE");
|
||||
if (override_path && override_path[0])
|
||||
{
|
||||
std::filesystem::path path = std::filesystem::path(override_path);
|
||||
std::error_code ec;
|
||||
if (std::filesystem::is_directory(path, ec))
|
||||
return path / "coop_autotune_v1.bin";
|
||||
return path;
|
||||
}
|
||||
|
||||
std::filesystem::path base;
|
||||
#ifdef _WIN32
|
||||
const char* local_app_data = std::getenv("LOCALAPPDATA");
|
||||
if (local_app_data && local_app_data[0])
|
||||
base = std::filesystem::path(local_app_data);
|
||||
else
|
||||
{
|
||||
const char* user_profile = std::getenv("USERPROFILE");
|
||||
if (!user_profile || !user_profile[0]) return {};
|
||||
std::filesystem::path user_path = std::filesystem::path(user_profile);
|
||||
std::filesystem::path local_path = user_path / "AppData" / "Local";
|
||||
std::error_code ec;
|
||||
if (std::filesystem::exists(local_path, ec))
|
||||
base = local_path;
|
||||
else
|
||||
base = user_path;
|
||||
}
|
||||
|
||||
return base / "exllamav3" / "autotune" / "coop_autotune_v1.bin";
|
||||
#else
|
||||
const char* xdg = std::getenv("XDG_CACHE_HOME");
|
||||
if (xdg && xdg[0])
|
||||
base = std::filesystem::path(xdg);
|
||||
else
|
||||
{
|
||||
const char* home = std::getenv("HOME");
|
||||
if (!home || !home[0]) return {};
|
||||
base = std::filesystem::path(home) / ".cache";
|
||||
}
|
||||
|
||||
return base / "exllamav3" / "autotune" / "coop_autotune_v1.bin";
|
||||
#endif
|
||||
}
|
||||
|
||||
bool read_disk_header(std::ifstream& in)
|
||||
{
|
||||
DiskCacheHeader header;
|
||||
in.read((char*) &header, sizeof(header));
|
||||
if (!in) return false;
|
||||
if (std::memcmp(header.magic, DISK_CACHE_MAGIC, sizeof(header.magic)) != 0) return false;
|
||||
if (header.format != DISK_CACHE_FORMAT) return false;
|
||||
if (header.record_size < sizeof(DiskCacheRecordV1)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
void load_disk_cache_locked()
|
||||
{
|
||||
if (disk_cache_loaded) return;
|
||||
disk_cache_loaded = true;
|
||||
|
||||
std::filesystem::path path = disk_cache_path();
|
||||
if (path.empty() || !std::filesystem::exists(path)) return;
|
||||
|
||||
std::ifstream in(path, std::ios::binary);
|
||||
if (!in || !read_disk_header(in)) return;
|
||||
|
||||
DiskCacheHeader header;
|
||||
in.seekg(0, std::ios::beg);
|
||||
in.read((char*) &header, sizeof(header));
|
||||
if (!in) return;
|
||||
|
||||
std::vector<char> record_bytes(header.record_size);
|
||||
while (in.read(record_bytes.data(), header.record_size))
|
||||
{
|
||||
DiskCacheRecordV1 record;
|
||||
std::memcpy(&record, record_bytes.data(), sizeof(record));
|
||||
disk_cache[record.hash] = record;
|
||||
}
|
||||
|
||||
#ifdef CACHEDEBUG
|
||||
printf
|
||||
(
|
||||
"coop_autotune cache loaded: path=%s records=%zu\n",
|
||||
path.string().c_str(),
|
||||
disk_cache.size()
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
void append_disk_cache_locked(const DiskCacheRecordV1& record)
|
||||
{
|
||||
std::filesystem::path path = disk_cache_path();
|
||||
if (path.empty()) return;
|
||||
|
||||
std::error_code ec;
|
||||
std::filesystem::create_directories(path.parent_path(), ec);
|
||||
if (ec) return;
|
||||
|
||||
bool write_header = !std::filesystem::exists(path) || std::filesystem::file_size(path, ec) == 0;
|
||||
std::ofstream out(path, std::ios::binary | std::ios::app);
|
||||
if (!out) return;
|
||||
|
||||
if (write_header)
|
||||
{
|
||||
DiskCacheHeader header;
|
||||
std::memcpy(header.magic, DISK_CACHE_MAGIC, sizeof(header.magic));
|
||||
header.format = DISK_CACHE_FORMAT;
|
||||
header.record_size = sizeof(DiskCacheRecordV1);
|
||||
out.write((const char*) &header, sizeof(header));
|
||||
}
|
||||
|
||||
out.write((const char*) &record, sizeof(record));
|
||||
}
|
||||
|
||||
bool launch_from_disk_cache
|
||||
(
|
||||
uint64_t hash,
|
||||
const std::vector<CoopAutotuneCandidate>& candidates,
|
||||
CoopAutotuneLaunch* launch_config
|
||||
)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(disk_mutex);
|
||||
load_disk_cache_locked();
|
||||
|
||||
auto lookup = disk_cache.find(hash);
|
||||
if (lookup == disk_cache.end()) return false;
|
||||
|
||||
const DiskCacheRecordV1& record = lookup->second;
|
||||
#ifdef CACHEDEBUG
|
||||
printf
|
||||
(
|
||||
"coop_autotune cache lookup: hash=%016llx tag=%d block_dim=%d num_sms=%d concurrency=%d\n",
|
||||
(unsigned long long) record.hash,
|
||||
record.tag,
|
||||
record.block_dim,
|
||||
record.num_sms,
|
||||
record.concurrency
|
||||
);
|
||||
#endif
|
||||
|
||||
for (const CoopAutotuneCandidate& candidate : candidates)
|
||||
{
|
||||
if (candidate.tag != record.tag) continue;
|
||||
if (candidate.block_dim != record.block_dim) continue;
|
||||
if (record.num_sms < 1 || record.num_sms > candidate.max_num_sms) continue;
|
||||
|
||||
int max_concurrency = MAX(candidate.max_concurrency, 1);
|
||||
int total_sms = candidate.total_sms > 0 ? candidate.total_sms : candidate.max_num_sms;
|
||||
int expected_concurrency = MAX(MIN(total_sms / record.num_sms, max_concurrency), 1);
|
||||
if (record.concurrency != expected_concurrency) continue;
|
||||
|
||||
*launch_config =
|
||||
{
|
||||
candidate.kernel,
|
||||
record.block_dim,
|
||||
record.num_sms,
|
||||
record.concurrency,
|
||||
record.tag
|
||||
};
|
||||
#ifdef CACHEDEBUG
|
||||
printf
|
||||
(
|
||||
"coop_autotune cache hit: hash=%016llx tag=%d block_dim=%d num_sms=%d concurrency=%d\n",
|
||||
(unsigned long long) record.hash,
|
||||
launch_config->tag,
|
||||
launch_config->block_dim,
|
||||
launch_config->num_sms,
|
||||
launch_config->concurrency
|
||||
);
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef CACHEDEBUG
|
||||
printf
|
||||
(
|
||||
"coop_autotune cache rejected: hash=%016llx tag=%d block_dim=%d num_sms=%d concurrency=%d\n",
|
||||
(unsigned long long) record.hash,
|
||||
record.tag,
|
||||
record.block_dim,
|
||||
record.num_sms,
|
||||
record.concurrency
|
||||
);
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
||||
void store_disk_cache(uint64_t hash, const CoopAutotuneLaunch& launch_config)
|
||||
{
|
||||
DiskCacheRecordV1 record =
|
||||
{
|
||||
hash,
|
||||
launch_config.tag,
|
||||
launch_config.block_dim,
|
||||
launch_config.num_sms,
|
||||
launch_config.concurrency,
|
||||
0,
|
||||
0
|
||||
};
|
||||
|
||||
std::lock_guard<std::mutex> lock(disk_mutex);
|
||||
load_disk_cache_locked();
|
||||
disk_cache[hash] = record;
|
||||
append_disk_cache_locked(record);
|
||||
|
||||
#ifdef CACHEDEBUG
|
||||
printf
|
||||
(
|
||||
"coop_autotune cache store: hash=%016llx tag=%d block_dim=%d num_sms=%d concurrency=%d\n",
|
||||
(unsigned long long) record.hash,
|
||||
record.tag,
|
||||
record.block_dim,
|
||||
record.num_sms,
|
||||
record.concurrency
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
void set_kernel_attr_once(void* kernel, size_t smem)
|
||||
{
|
||||
int device;
|
||||
cuda_check(cudaGetDevice(&device));
|
||||
|
||||
auto key = std::make_tuple(device, kernel, smem);
|
||||
if (attr_set.find(key) != attr_set.end()) return;
|
||||
|
||||
cuda_check(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smem));
|
||||
attr_set.insert(key);
|
||||
}
|
||||
|
||||
float trimmed_mean(std::vector<float>& samples)
|
||||
{
|
||||
TORCH_CHECK(!samples.empty(), "CoopKernelAutotuner: no timing samples");
|
||||
std::sort(samples.begin(), samples.end());
|
||||
|
||||
size_t trim = samples.size() / 4;
|
||||
size_t begin = trim;
|
||||
size_t end = samples.size() - trim;
|
||||
if (begin >= end)
|
||||
{
|
||||
begin = 0;
|
||||
end = samples.size();
|
||||
}
|
||||
|
||||
double sum = 0.0;
|
||||
for (size_t i = begin; i < end; ++i) sum += samples[i];
|
||||
return (float) (sum / (double) (end - begin));
|
||||
}
|
||||
|
||||
void measure_candidate_sample
|
||||
(
|
||||
ExpandedCandidate& candidate,
|
||||
void** kernel_args,
|
||||
size_t smem,
|
||||
cudaStream_t stream,
|
||||
int repeats,
|
||||
cudaEvent_t start,
|
||||
cudaEvent_t end
|
||||
)
|
||||
{
|
||||
cuda_check(cudaEventRecord(start, stream));
|
||||
for (int i = 0; i < repeats; ++i)
|
||||
{
|
||||
cuda_check(cudaLaunchCooperativeKernel
|
||||
(
|
||||
candidate.kernel,
|
||||
dim3(candidate.num_sms, 1, candidate.concurrency),
|
||||
candidate.block_dim,
|
||||
kernel_args,
|
||||
smem,
|
||||
stream
|
||||
));
|
||||
}
|
||||
cuda_check(cudaEventRecord(end, stream));
|
||||
cuda_check(cudaEventSynchronize(end));
|
||||
|
||||
float ms = 0.0f;
|
||||
cuda_check(cudaEventElapsedTime(&ms, start, end));
|
||||
candidate.samples.push_back(ms / (float) repeats);
|
||||
}
|
||||
|
||||
void keep_best(std::vector<ExpandedCandidate>& candidates, size_t limit)
|
||||
{
|
||||
for (ExpandedCandidate& candidate : candidates)
|
||||
candidate.latency = trimmed_mean(candidate.samples);
|
||||
|
||||
std::sort
|
||||
(
|
||||
candidates.begin(),
|
||||
candidates.end(),
|
||||
[] (const ExpandedCandidate& a, const ExpandedCandidate& b)
|
||||
{
|
||||
return a.latency < b.latency;
|
||||
}
|
||||
);
|
||||
if (candidates.size() > limit) candidates.resize(limit);
|
||||
}
|
||||
|
||||
void measure_stage
|
||||
(
|
||||
std::vector<ExpandedCandidate>& candidates,
|
||||
void** kernel_args,
|
||||
size_t smem,
|
||||
cudaStream_t stream,
|
||||
int rounds,
|
||||
int repeats,
|
||||
size_t keep,
|
||||
cudaEvent_t start,
|
||||
cudaEvent_t end
|
||||
)
|
||||
{
|
||||
TORCH_CHECK(!candidates.empty(), "CoopKernelAutotuner: no candidates in stage");
|
||||
|
||||
for (ExpandedCandidate& candidate : candidates)
|
||||
{
|
||||
candidate.samples.clear();
|
||||
candidate.samples.reserve(rounds);
|
||||
set_kernel_attr_once(candidate.kernel, smem);
|
||||
|
||||
// One untimed launch avoids first-use effects from contaminating the first measured round.
|
||||
cuda_check(cudaLaunchCooperativeKernel
|
||||
(
|
||||
candidate.kernel,
|
||||
dim3(candidate.num_sms, 1, candidate.concurrency),
|
||||
candidate.block_dim,
|
||||
kernel_args,
|
||||
smem,
|
||||
stream
|
||||
));
|
||||
}
|
||||
cuda_check(cudaStreamSynchronize(stream));
|
||||
|
||||
std::vector<int> order(candidates.size());
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
|
||||
for (int round = 0; round < rounds; ++round)
|
||||
{
|
||||
size_t n = order.size();
|
||||
size_t step = 2 * (size_t) round + 1;
|
||||
while (std::gcd(step, n) != 1) step += 2;
|
||||
size_t pos = ((uint64_t) round * 1103515245ull + 12345ull) % n;
|
||||
|
||||
for (size_t i = 0; i < n; ++i)
|
||||
{
|
||||
measure_candidate_sample
|
||||
(
|
||||
candidates[order[pos]],
|
||||
kernel_args,
|
||||
smem,
|
||||
stream,
|
||||
repeats,
|
||||
start,
|
||||
end
|
||||
);
|
||||
pos = (pos + step) % n;
|
||||
}
|
||||
}
|
||||
|
||||
keep_best(candidates, keep);
|
||||
}
|
||||
|
||||
CoopAutotuneLaunch tune
|
||||
(
|
||||
const std::vector<CoopAutotuneCandidate>& base_candidates,
|
||||
void** kernel_args,
|
||||
size_t smem,
|
||||
cudaStream_t stream,
|
||||
size_t numel_B
|
||||
)
|
||||
{
|
||||
std::vector<ExpandedCandidate> candidates;
|
||||
for (const CoopAutotuneCandidate& base : base_candidates)
|
||||
{
|
||||
TORCH_CHECK(base.kernel, "CoopKernelAutotuner: null kernel candidate");
|
||||
TORCH_CHECK(base.block_dim > 0, "CoopKernelAutotuner: invalid block_dim");
|
||||
TORCH_CHECK(base.max_num_sms > 0, "CoopKernelAutotuner: invalid max_num_sms");
|
||||
int max_concurrency = MAX(base.max_concurrency, 1);
|
||||
int total_sms = base.total_sms > 0 ? base.total_sms : base.max_num_sms;
|
||||
|
||||
if (max_concurrency > 1 || base.max_num_sms == 1)
|
||||
{
|
||||
int concurrency = MAX(MIN(total_sms, max_concurrency), 1);
|
||||
candidates.push_back({ base.kernel, base.block_dim, 1, concurrency, base.tag, 0.0f, {} });
|
||||
}
|
||||
|
||||
for (int num_sms = 2; num_sms <= base.max_num_sms * 85 / 100; num_sms += 2)
|
||||
{
|
||||
int concurrency = MAX(MIN(total_sms / num_sms, max_concurrency), 1);
|
||||
candidates.push_back({ base.kernel, base.block_dim, num_sms, concurrency, base.tag, 0.0f, {} });
|
||||
}
|
||||
|
||||
if (base.max_num_sms > 1)
|
||||
{
|
||||
int concurrency = MAX(MIN(total_sms / base.max_num_sms, max_concurrency), 1);
|
||||
candidates.push_back({ base.kernel, base.block_dim, base.max_num_sms, concurrency, base.tag, 0.0f, {} });
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(!candidates.empty(), "CoopKernelAutotuner: no candidates");
|
||||
|
||||
cudaEvent_t start;
|
||||
cudaEvent_t end;
|
||||
cuda_check(cudaEventCreate(&start));
|
||||
cuda_check(cudaEventCreate(&end));
|
||||
|
||||
int repeats = 20;
|
||||
if (numel_B > 1e6) repeats = 10;
|
||||
if (numel_B > 1e7) repeats = 5;
|
||||
if (numel_B > 1e8) repeats = 3;
|
||||
int max_rounds = 64;
|
||||
if (numel_B > 1e7) max_rounds = 20;
|
||||
if (numel_B > 1e8) max_rounds = 10;
|
||||
|
||||
measure_stage(candidates, kernel_args, smem, stream, MIN(8, max_rounds), repeats, 16, start, end);
|
||||
measure_stage(candidates, kernel_args, smem, stream, MIN(40, max_rounds), repeats, 8, start, end);
|
||||
measure_stage(candidates, kernel_args, smem, stream, MIN(64, max_rounds), repeats, 1, start, end);
|
||||
|
||||
cuda_check(cudaEventDestroy(start));
|
||||
cuda_check(cudaEventDestroy(end));
|
||||
|
||||
const ExpandedCandidate& best = candidates[0];
|
||||
return { best.kernel, best.block_dim, best.num_sms, best.concurrency, best.tag };
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool CoopKernelAutotuner::launch_locked
|
||||
(
|
||||
uint64_t hash,
|
||||
void** kernel_args,
|
||||
size_t smem,
|
||||
cudaStream_t stream,
|
||||
CoopAutotuneLaunch* out_launch_config
|
||||
)
|
||||
{
|
||||
CoopAutotuneLaunch launch_config;
|
||||
hash = salt_hash(hash);
|
||||
|
||||
{
|
||||
auto lookup = launch_cache.find(hash);
|
||||
if (lookup == launch_cache.end()) return false;
|
||||
launch_config = lookup->second;
|
||||
}
|
||||
|
||||
set_kernel_attr_once(launch_config.kernel, smem);
|
||||
cuda_check(cudaLaunchCooperativeKernel
|
||||
(
|
||||
launch_config.kernel,
|
||||
dim3(launch_config.num_sms, 1, launch_config.concurrency),
|
||||
launch_config.block_dim,
|
||||
kernel_args,
|
||||
smem,
|
||||
stream
|
||||
));
|
||||
|
||||
if (out_launch_config) *out_launch_config = launch_config;
|
||||
return true;
|
||||
}
|
||||
|
||||
CoopAutotuneLaunch CoopKernelAutotuner::launch
|
||||
(
|
||||
uint64_t hash,
|
||||
const std::vector<CoopAutotuneCandidate>& candidates,
|
||||
void** kernel_args,
|
||||
size_t smem,
|
||||
cudaStream_t stream,
|
||||
size_t numel_B
|
||||
)
|
||||
{
|
||||
CoopAutotuneLaunch launch_config;
|
||||
uint64_t salted_hash = salt_hash(hash);
|
||||
|
||||
if (launch_locked(hash, kernel_args, smem, stream, &launch_config))
|
||||
return launch_config;
|
||||
|
||||
if (launch_from_disk_cache(salted_hash, candidates, &launch_config))
|
||||
{
|
||||
launch_cache[salted_hash] = launch_config;
|
||||
set_kernel_attr_once(launch_config.kernel, smem);
|
||||
cuda_check(cudaLaunchCooperativeKernel
|
||||
(
|
||||
launch_config.kernel,
|
||||
dim3(launch_config.num_sms, 1, launch_config.concurrency),
|
||||
launch_config.block_dim,
|
||||
kernel_args,
|
||||
smem,
|
||||
stream
|
||||
));
|
||||
return launch_config;
|
||||
}
|
||||
|
||||
{
|
||||
launch_config = tune(candidates, kernel_args, smem, stream, numel_B);
|
||||
auto inserted = launch_cache.emplace(salted_hash, launch_config);
|
||||
launch_config = inserted.first->second;
|
||||
}
|
||||
store_disk_cache(salted_hash, launch_config);
|
||||
|
||||
set_kernel_attr_once(launch_config.kernel, smem);
|
||||
cuda_check(cudaLaunchCooperativeKernel
|
||||
(
|
||||
launch_config.kernel,
|
||||
dim3(launch_config.num_sms, 1, launch_config.concurrency),
|
||||
launch_config.block_dim,
|
||||
kernel_args,
|
||||
smem,
|
||||
stream
|
||||
));
|
||||
|
||||
return launch_config;
|
||||
}
|
||||
47
exllamav3/exllamav3_ext/quant/coop_autotune.cuh
Normal file
47
exllamav3/exllamav3_ext/quant/coop_autotune.cuh
Normal file
@@ -0,0 +1,47 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
struct CoopAutotuneCandidate
|
||||
{
|
||||
void* kernel;
|
||||
int block_dim;
|
||||
int max_num_sms;
|
||||
int max_concurrency;
|
||||
int total_sms;
|
||||
int tag;
|
||||
};
|
||||
|
||||
struct CoopAutotuneLaunch
|
||||
{
|
||||
void* kernel;
|
||||
int block_dim;
|
||||
int num_sms;
|
||||
int concurrency;
|
||||
int tag;
|
||||
};
|
||||
|
||||
class CoopKernelAutotuner
|
||||
{
|
||||
public:
|
||||
static bool launch_locked
|
||||
(
|
||||
uint64_t hash,
|
||||
void** kernel_args,
|
||||
size_t smem,
|
||||
cudaStream_t stream,
|
||||
CoopAutotuneLaunch* launch_config = nullptr
|
||||
);
|
||||
|
||||
static CoopAutotuneLaunch launch
|
||||
(
|
||||
uint64_t hash,
|
||||
const std::vector<CoopAutotuneCandidate>& candidates,
|
||||
void** kernel_args,
|
||||
size_t smem,
|
||||
cudaStream_t stream,
|
||||
size_t numel_B = 1e9
|
||||
);
|
||||
};
|
||||
@@ -11,13 +11,13 @@ namespace cg = cooperative_groups;
|
||||
#include "exl3_kernel_map.cuh"
|
||||
#include "exl3_devctx.cuh"
|
||||
#include "exl3_gemv.cuh"
|
||||
#include "coop_autotune.cuh"
|
||||
#include <set>
|
||||
|
||||
#define NEW_TUNE_GEMM
|
||||
#define NEW_TUNE_MGEMM
|
||||
#include <vector>
|
||||
|
||||
int exl3_gemm_tilesize_k_g[] = {EXL3_GEMM_TILESIZE_K};
|
||||
int exl3_gemm_tilesize_n_g[] = {EXL3_GEMM_TILESIZE_N};
|
||||
int exl3_gemm_blockdim_g[] = {EXL3_GEMM_BLOCKDIM};
|
||||
|
||||
/*
|
||||
EXL3 matmul, A @ B -> C
|
||||
@@ -36,6 +36,63 @@ limitations:
|
||||
|
||||
std::set<void*> kernel_attr_set[MAX_DEVICES] = {};
|
||||
|
||||
uint64_t gemm_autotune_hash
|
||||
(
|
||||
int size_m,
|
||||
int size_k,
|
||||
int size_n,
|
||||
int K,
|
||||
bool c_fp32,
|
||||
int device,
|
||||
int cc,
|
||||
int max_num_sms,
|
||||
int cb
|
||||
)
|
||||
{
|
||||
uint64_t h = 1469598103934665603ull;
|
||||
auto mix = [&] (uint64_t v)
|
||||
{
|
||||
h ^= v;
|
||||
h *= 1099511628211ull;
|
||||
};
|
||||
mix((uint64_t) size_m);
|
||||
mix((uint64_t) size_k);
|
||||
mix((uint64_t) size_n);
|
||||
mix((uint64_t) K);
|
||||
mix(c_fp32 ? 1ull : 0ull);
|
||||
mix((uint64_t) device);
|
||||
mix((uint64_t) cc);
|
||||
mix((uint64_t) max_num_sms);
|
||||
mix((uint64_t) cb);
|
||||
return h;
|
||||
}
|
||||
|
||||
uint64_t mgemm_autotune_hash
|
||||
(
|
||||
int size_m,
|
||||
int size_k,
|
||||
int size_n,
|
||||
int K,
|
||||
bool c_fp32,
|
||||
int device,
|
||||
int cc,
|
||||
int max_num_sms,
|
||||
int cb,
|
||||
int bszm_in,
|
||||
int bszm_out
|
||||
)
|
||||
{
|
||||
uint64_t h = gemm_autotune_hash(size_m, size_k, size_n, K, c_fp32, device, cc, max_num_sms, cb);
|
||||
auto mix = [&] (uint64_t v)
|
||||
{
|
||||
h ^= v;
|
||||
h *= 1099511628211ull;
|
||||
};
|
||||
mix((uint64_t) bszm_in);
|
||||
mix((uint64_t) bszm_out);
|
||||
return h;
|
||||
}
|
||||
|
||||
int exl3_gemm_gr
|
||||
(
|
||||
const at::Tensor& A,
|
||||
@@ -108,30 +165,6 @@ int exl3_gemm_gr
|
||||
int shape_idx;
|
||||
fp_exl3_gemm_kernel kernel;
|
||||
|
||||
#ifndef NEW_TUNE_GEMM
|
||||
kernel = select_exl3_gemm_kernel
|
||||
(
|
||||
cc, size_m, size_k, size_n, K, c_fp32,
|
||||
force_shape_idx, &block_dim, &shape_idx,
|
||||
&num_sms, cb
|
||||
);
|
||||
if (!kernel) return 0;
|
||||
#else
|
||||
TResult* tr = select_exl3_gemm_mgemm_kernel_new(cc, size_m, size_k, size_n, K, c_fp32, force_shape_idx, force_num_sms, cb);
|
||||
if (!tr) return 0;
|
||||
num_sms = MIN(num_sms, tr->num_sms);
|
||||
kernel = tr->kernel;
|
||||
block_dim = tr->block_dim;
|
||||
shape_idx = tr->shape_idx;
|
||||
#endif
|
||||
|
||||
// Launch
|
||||
if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end())
|
||||
{
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
|
||||
kernel_attr_set[device].insert((void*) kernel);
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
}
|
||||
void* kernelArgs[] =
|
||||
{
|
||||
(void*)& A_ptr,
|
||||
@@ -145,6 +178,79 @@ int exl3_gemm_gr
|
||||
(void*)& A_had_ptr,
|
||||
(void*)& svh_ptr
|
||||
};
|
||||
|
||||
auto add_graph_args = [&](void* kernel_ptr)
|
||||
{
|
||||
if (graph)
|
||||
{
|
||||
graph->record_param(kernel_ptr, GP_gemm_A, 0);
|
||||
graph->record_param(kernel_ptr, GP_gemm_B_trellis, 1);
|
||||
graph->record_param(kernel_ptr, GP_gemm_C, 2);
|
||||
graph->record_param(kernel_ptr, GP_gemm_B_suh, 7);
|
||||
graph->record_param(kernel_ptr, GP_gemm_A_had, 8);
|
||||
graph->record_param(kernel_ptr, GP_gemm_B_svh, 9);
|
||||
graph->record_param(kernel_ptr, GP_end, 0);
|
||||
}
|
||||
};
|
||||
|
||||
bool autotune = force_shape_idx <= 0 && force_num_sms <= 0;
|
||||
if (autotune)
|
||||
{
|
||||
uint64_t autotune_key = gemm_autotune_hash(MAX(size_m, 2), size_k, size_n, K, c_fp32, device, cc, num_sms, cb);
|
||||
CoopAutotuneLaunch tuned;
|
||||
if (CoopKernelAutotuner::launch_locked(autotune_key, kernelArgs, SMEM_MAX, stream, &tuned))
|
||||
{
|
||||
add_graph_args((void*) tuned.kernel);
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
return tuned.tag;
|
||||
}
|
||||
std::vector<CoopAutotuneCandidate> candidates;
|
||||
for (int candidate_shape_idx = 1; candidate_shape_idx <= EXL3_GEMM_NUM_SHAPES; ++candidate_shape_idx)
|
||||
{
|
||||
if (!exl3_gemm_shape_compat(candidate_shape_idx, size_m, size_k, size_n, K)) continue;
|
||||
|
||||
fp_exl3_gemm_kernel candidate_kernel = get_gemm_kernel_ptr(K, candidate_shape_idx, c_fp32, cb);
|
||||
if (!candidate_kernel) continue;
|
||||
|
||||
int tilesize_k = exl3_gemm_tilesize_k_g[candidate_shape_idx];
|
||||
int tilesize_n = exl3_gemm_tilesize_n_g[candidate_shape_idx];
|
||||
int max_slices = MAX(size_k / tilesize_k * size_n / tilesize_n, 1);
|
||||
int max_candidate_sms = MAX(MIN(max_slices, num_sms), 1);
|
||||
|
||||
candidates.push_back
|
||||
({
|
||||
(void*) candidate_kernel,
|
||||
exl3_gemm_blockdim_g[candidate_shape_idx],
|
||||
max_candidate_sms,
|
||||
1,
|
||||
max_candidate_sms,
|
||||
candidate_shape_idx
|
||||
});
|
||||
}
|
||||
TORCH_CHECK(!candidates.empty(), "exl3_gemm autotune: no compatible kernel shapes");
|
||||
|
||||
tuned = CoopKernelAutotuner::launch(autotune_key, candidates, kernelArgs, SMEM_MAX, stream, (size_t) size_k * size_n);
|
||||
if (graph)
|
||||
add_graph_args((void*) tuned.kernel);
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
return tuned.tag;
|
||||
}
|
||||
|
||||
kernel = select_exl3_gemm_kernel
|
||||
(
|
||||
cc, size_m, size_k, size_n, K, c_fp32,
|
||||
force_shape_idx, &block_dim, &shape_idx,
|
||||
&num_sms, cb
|
||||
);
|
||||
if (!kernel) return 0;
|
||||
|
||||
// Launch
|
||||
if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end())
|
||||
{
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
|
||||
kernel_attr_set[device].insert((void*) kernel);
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
}
|
||||
cudaLaunchCooperativeKernel
|
||||
(
|
||||
(void*) kernel,
|
||||
@@ -154,14 +260,7 @@ int exl3_gemm_gr
|
||||
SMEM_MAX,
|
||||
stream
|
||||
);
|
||||
|
||||
if (graph) graph->record_param((void*) kernel, GP_gemm_A, 0);
|
||||
if (graph) graph->record_param((void*) kernel, GP_gemm_B_trellis, 1);
|
||||
if (graph) graph->record_param((void*) kernel, GP_gemm_C, 2);
|
||||
if (graph) graph->record_param((void*) kernel, GP_gemm_B_suh, 7);
|
||||
if (graph) graph->record_param((void*) kernel, GP_gemm_A_had, 8);
|
||||
if (graph) graph->record_param((void*) kernel, GP_gemm_B_svh, 9);
|
||||
if (graph) graph->record_param((void*) kernel, GP_end, 0);
|
||||
add_graph_args((void*) kernel);
|
||||
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
return shape_idx;
|
||||
@@ -304,40 +403,6 @@ int exl3_mgemm_gr
|
||||
fp_exl3_mgemm_kernel kernel;
|
||||
int concurrency;
|
||||
|
||||
#ifndef NEW_TUNE_MGEMM
|
||||
kernel = select_exl3_mgemm_kernel
|
||||
(
|
||||
cc, size_m, size_k, size_n, K, c_fp32,
|
||||
force_shape_idx, &block_dim, &shape_idx,
|
||||
&num_sms, cb, bszm_in, bszm_out
|
||||
);
|
||||
if (!kernel) return 0;
|
||||
concurrency = MIN(total_sms / num_sms, bszm_out);
|
||||
#else
|
||||
kernel = select_exl3_mgemm_kernel
|
||||
(
|
||||
cc, size_m, size_k, size_n, K, c_fp32,
|
||||
force_shape_idx, &block_dim, &shape_idx,
|
||||
&num_sms, cb, bszm_in, bszm_out
|
||||
);
|
||||
int tilesize_k = exl3_gemm_tilesize_k_g[shape_idx];
|
||||
int tilesize_n = exl3_gemm_tilesize_n_g[shape_idx];
|
||||
int tiles = MAX(size_k / tilesize_k * size_n / tilesize_n, 1);
|
||||
num_sms = tiles;
|
||||
if (num_sms * bszm > total_sms) num_sms = MAX(total_sms / bszm, 1);
|
||||
if (num_sms <= total_sms && tiles / num_sms > 48) num_sms = MIN(total_sms, num_sms * 2);
|
||||
concurrency = MIN(total_sms / num_sms, bszm);
|
||||
#endif
|
||||
|
||||
// Launch bigger grid if possible
|
||||
dim3 block_grid(num_sms, 1, concurrency);
|
||||
|
||||
// Launch
|
||||
if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end())
|
||||
{
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
|
||||
kernel_attr_set[device].insert((void*) kernel);
|
||||
}
|
||||
void* kernelArgs[] =
|
||||
{
|
||||
(void*)& A_ptr,
|
||||
@@ -358,6 +423,96 @@ int exl3_mgemm_gr
|
||||
(void*)& max_index
|
||||
};
|
||||
|
||||
auto add_graph_args = [&](void* kernel_ptr)
|
||||
{
|
||||
if (graph)
|
||||
{
|
||||
graph->record_param(kernel_ptr, GP_mgemm_A, 0);
|
||||
graph->record_param(kernel_ptr, GP_mgemm_C, 2);
|
||||
graph->record_param(kernel_ptr, GP_mgemm_indices, 10);
|
||||
graph->record_param(kernel_ptr, GP_mgemm_weights, 11);
|
||||
graph->record_param(kernel_ptr, GP_end, 0);
|
||||
}
|
||||
};
|
||||
|
||||
bool autotune = force_shape_idx <= 0 && force_num_sms <= 0;
|
||||
if (autotune)
|
||||
{
|
||||
uint64_t autotune_key = mgemm_autotune_hash
|
||||
(
|
||||
size_m, size_k, size_n, K, c_fp32, device, cc, total_sms, cb, bszm_in, bszm_out
|
||||
);
|
||||
|
||||
CoopAutotuneLaunch tuned;
|
||||
if (CoopKernelAutotuner::launch_locked(autotune_key, kernelArgs, SMEM_MAX, stream, &tuned))
|
||||
{
|
||||
add_graph_args((void*) tuned.kernel);
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
return tuned.tag;
|
||||
}
|
||||
if (!graph)
|
||||
{
|
||||
std::vector<CoopAutotuneCandidate> candidates;
|
||||
for (int candidate_shape_idx = 1; candidate_shape_idx <= EXL3_GEMM_NUM_SHAPES; ++candidate_shape_idx)
|
||||
{
|
||||
if (!exl3_gemm_shape_compat(candidate_shape_idx, size_m, size_k, size_n, K)) continue;
|
||||
|
||||
fp_exl3_mgemm_kernel candidate_kernel = get_mgemm_kernel_ptr(K, candidate_shape_idx, c_fp32, cb);
|
||||
if (!candidate_kernel) continue;
|
||||
|
||||
int tilesize_k = exl3_gemm_tilesize_k_g[candidate_shape_idx];
|
||||
int tilesize_n = exl3_gemm_tilesize_n_g[candidate_shape_idx];
|
||||
int max_slices = MAX(size_k / tilesize_k * size_n / tilesize_n, 1);
|
||||
int max_candidate_sms = MAX(MIN(max_slices, total_sms), 1);
|
||||
|
||||
candidates.push_back
|
||||
({
|
||||
(void*) candidate_kernel,
|
||||
exl3_gemm_blockdim_g[candidate_shape_idx],
|
||||
max_candidate_sms,
|
||||
bszm,
|
||||
total_sms,
|
||||
candidate_shape_idx
|
||||
});
|
||||
}
|
||||
TORCH_CHECK(!candidates.empty(), "exl3_mgemm autotune: no compatible kernel shapes");
|
||||
|
||||
tuned = CoopKernelAutotuner::launch(autotune_key, candidates, kernelArgs, SMEM_MAX, stream, (size_t) size_k * size_n * bszm);
|
||||
add_graph_args((void*) tuned.kernel);
|
||||
|
||||
// DBGI10(size_m, size_k, size_n, K, bszm_in, bszm_out, tuned.tag, tuned.block_dim, tuned.num_sms, tuned.concurrency);
|
||||
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
return tuned.tag;
|
||||
}
|
||||
}
|
||||
|
||||
kernel = select_exl3_mgemm_kernel
|
||||
(
|
||||
cc, size_m, size_k, size_n, K, c_fp32,
|
||||
force_shape_idx, &block_dim, &shape_idx,
|
||||
&num_sms, cb, bszm_in, bszm_out
|
||||
);
|
||||
int tilesize_k = exl3_gemm_tilesize_k_g[shape_idx];
|
||||
int tilesize_n = exl3_gemm_tilesize_n_g[shape_idx];
|
||||
int tiles = MAX(size_k / tilesize_k * size_n / tilesize_n, 1);
|
||||
num_sms = tiles;
|
||||
if (num_sms * bszm > total_sms) num_sms = MAX(total_sms / bszm, 1);
|
||||
if (num_sms <= total_sms && tiles / num_sms > 48) num_sms = MIN(total_sms, num_sms * 2);
|
||||
concurrency = MIN(total_sms / num_sms, bszm);
|
||||
|
||||
// DBGI10(size_m, size_k, size_n, K, bszm_in, bszm_out, shape_idx, block_dim, num_sms, concurrency);
|
||||
|
||||
// Launch bigger grid if possible
|
||||
dim3 block_grid(num_sms, 1, concurrency);
|
||||
|
||||
// Launch
|
||||
if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end())
|
||||
{
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_MAX);
|
||||
kernel_attr_set[device].insert((void*) kernel);
|
||||
}
|
||||
|
||||
cudaLaunchCooperativeKernel
|
||||
(
|
||||
(void*) kernel,
|
||||
@@ -367,12 +522,7 @@ int exl3_mgemm_gr
|
||||
SMEM_MAX,
|
||||
stream
|
||||
);
|
||||
|
||||
if (graph) graph->record_param((void*) kernel, GP_mgemm_A, 0);
|
||||
if (graph) graph->record_param((void*) kernel, GP_mgemm_C, 2);
|
||||
if (graph) graph->record_param((void*) kernel, GP_mgemm_indices, 10);
|
||||
if (graph) graph->record_param((void*) kernel, GP_mgemm_weights, 11);
|
||||
if (graph) graph->record_param((void*) kernel, GP_end, 0);
|
||||
add_graph_args((void*) kernel);
|
||||
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
return shape_idx;
|
||||
|
||||
@@ -20,9 +20,6 @@ namespace cg = cooperative_groups;
|
||||
#include "comp_units/exl3_comp_unit_7.cuh"
|
||||
#include "comp_units/exl3_comp_unit_8.cuh"
|
||||
|
||||
#include "exl3_kernel_map_samples.cuh"
|
||||
std::map<uint64_t, TResult> _tuning_cache = {};
|
||||
|
||||
int select_gemm_shape(int cc, int size_m, int size_k, int size_n, int K, bool multi, int bszm_in, int bszm_out)
|
||||
{
|
||||
bool mod_256 = (size_n % 256 == 0);
|
||||
@@ -295,85 +292,3 @@ fp_exl3_mgemm_kernel get_mgemm_kernel_ptr(int K, int shape_idx, bool c_fp32, int
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TResult f_tr;
|
||||
|
||||
TResult* select_exl3_gemm_mgemm_kernel_new
|
||||
(
|
||||
int cc,
|
||||
int size_m,
|
||||
int size_k,
|
||||
int size_n,
|
||||
int K,
|
||||
bool c_fp32,
|
||||
int force_shape_idx,
|
||||
int force_num_sms,
|
||||
int cb
|
||||
)
|
||||
{
|
||||
// Force parameters for tuning/benchmarking
|
||||
if (force_shape_idx > 0)
|
||||
{
|
||||
TORCH_CHECK(force_num_sms, "Must supply force_shape_idx and force_num_sms together");
|
||||
f_tr.kernel = get_gemm_kernel_ptr(K, force_shape_idx, c_fp32, cb);
|
||||
f_tr.mkernel = get_mgemm_kernel_ptr(K, force_shape_idx, c_fp32, cb);
|
||||
f_tr.shape_idx = force_shape_idx;
|
||||
f_tr.num_sms = force_num_sms;
|
||||
f_tr.block_dim = exl3_gemm_blockdim[force_shape_idx];
|
||||
return &f_tr;
|
||||
};
|
||||
TORCH_CHECK(!force_num_sms, "Must supply force_shape_idx and force_num_sms together.");
|
||||
|
||||
// Cache parameters
|
||||
uint64_t key = (((uint64_t) size_k) << 40) |
|
||||
(((uint64_t) size_n) << 16) |
|
||||
(((uint64_t) cc) << 8) |
|
||||
(((uint64_t) K) << 4) |
|
||||
(c_fp32 ? 0x01ull : 0x00ull);
|
||||
|
||||
auto lookup = _tuning_cache.find(key);
|
||||
if (lookup == _tuning_cache.end())
|
||||
{
|
||||
// Find closest kernel in map
|
||||
bool mod512 = (size_n % 512 == 0);
|
||||
bool mod256 = (size_n % 256 == 0);
|
||||
bool mod128 = (size_n % 128 == 0);
|
||||
TORCH_CHECK(mod128, "size_n must be a multiple of 128");
|
||||
TSample* cand = mod512 ? samples_512 : (mod256 ? samples_256 : samples_128);
|
||||
TSample* best = nullptr;
|
||||
int64_t best_dist = 1ll<<62;
|
||||
|
||||
for (; cand->K; cand++)
|
||||
{
|
||||
if (cand->K != K) continue;
|
||||
if (cand->cc != cc) continue;
|
||||
|
||||
int64_t distk = (int64_t) (size_k - cand->k);
|
||||
int64_t distn = (int64_t) (size_n - cand->n);
|
||||
int64_t dist = distk * distk + distn * distn;
|
||||
if (dist < best_dist) { best_dist = dist; best = cand; }
|
||||
}
|
||||
TORCH_CHECK(best, "Failed to find valid kernel for shape");
|
||||
|
||||
// Avoid empty blocks
|
||||
int tilesize_k = exl3_gemm_tilesize_k[best->shape_idx];
|
||||
int tilesize_n = exl3_gemm_tilesize_n[best->shape_idx];
|
||||
int max_slices = size_k / tilesize_k * size_n / tilesize_n;
|
||||
int num_sms = MAX(MIN(max_slices, best->num_sms), 1);
|
||||
|
||||
// Results
|
||||
TResult tr = {
|
||||
get_gemm_kernel_ptr(K, best->shape_idx, c_fp32, cb),
|
||||
get_mgemm_kernel_ptr(K, best->shape_idx, c_fp32, cb),
|
||||
best->shape_idx,
|
||||
num_sms,
|
||||
exl3_gemm_blockdim[best->shape_idx]
|
||||
};
|
||||
|
||||
_tuning_cache[key] = tr;
|
||||
}
|
||||
|
||||
lookup = _tuning_cache.find(key);
|
||||
return &(lookup->second);
|
||||
}
|
||||
@@ -138,46 +138,5 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel
|
||||
const int bszm_out
|
||||
);
|
||||
|
||||
struct TSample {
|
||||
int cc;
|
||||
int K;
|
||||
int m;
|
||||
int k;
|
||||
int n;
|
||||
int shape_idx;
|
||||
int num_sms;
|
||||
};
|
||||
|
||||
struct TMSample {
|
||||
int cc;
|
||||
int K;
|
||||
int m;
|
||||
int k;
|
||||
int n;
|
||||
int shape_idx;
|
||||
int num_sms;
|
||||
int bszm_in;
|
||||
int bszm_out;
|
||||
};
|
||||
|
||||
struct TResult
|
||||
{
|
||||
fp_exl3_gemm_kernel kernel;
|
||||
fp_exl3_mgemm_kernel mkernel;
|
||||
int shape_idx;
|
||||
int num_sms;
|
||||
int block_dim;
|
||||
};
|
||||
|
||||
TResult* select_exl3_gemm_mgemm_kernel_new
|
||||
(
|
||||
const int cc,
|
||||
const int size_m,
|
||||
const int size_k,
|
||||
const int size_n,
|
||||
const int K,
|
||||
const bool c_fp32,
|
||||
const int force_shape_idx,
|
||||
const int force_num_sms,
|
||||
const int cb
|
||||
);
|
||||
fp_exl3_gemm_kernel get_gemm_kernel_ptr(int K, int shape_idx, bool c_fp32, int cb);
|
||||
fp_exl3_mgemm_kernel get_mgemm_kernel_ptr(int K, int shape_idx, bool c_fp32, int cb);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,561 +0,0 @@
|
||||
import sys, os
|
||||
from collections import OrderedDict
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
import torch
|
||||
from exllamav3.ext import exllamav3_ext as ext
|
||||
from exllamav3.util import Timer
|
||||
from exllamav3.util.memory import free_mem
|
||||
from tabulate import tabulate
|
||||
import numpy as np
|
||||
|
||||
num_warmup_passes = 10
|
||||
num_benchmark_iter_a = 20
|
||||
num_benchmark_iter_b = 40
|
||||
outlier_trim = 0.3
|
||||
assume_cache = 384 * 1024 ** 2
|
||||
|
||||
devices = [
|
||||
"cuda:1",
|
||||
"cuda:2",
|
||||
"cuda:3",
|
||||
]
|
||||
|
||||
shapes_m = [1]
|
||||
|
||||
shapes_k = [
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
5120,
|
||||
8192,
|
||||
12288,
|
||||
14336,
|
||||
16384,
|
||||
24576,
|
||||
]
|
||||
|
||||
shapes_n = [
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
5120,
|
||||
8192,
|
||||
12288,
|
||||
14336,
|
||||
16384,
|
||||
24576,
|
||||
51200,
|
||||
128000,
|
||||
]
|
||||
|
||||
shape_indices_128 = [1, 2]
|
||||
shape_indices_256 = [3]
|
||||
shape_indices_512 = [4]
|
||||
|
||||
Ks = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
|
||||
mgemm_bszm_io = [
|
||||
(1, 8),
|
||||
(8, 1),
|
||||
]
|
||||
|
||||
g_spin = 0
|
||||
|
||||
def get_abc(K, m, k, n, device):
|
||||
proto_a = torch.randn((m, k), dtype = torch.half, device = device)
|
||||
proto_b = torch.zeros((k // 16, n // 16, 16 * K), dtype = torch.short, device = device)
|
||||
proto_c = torch.zeros((m, n), dtype = torch.half, device = device)
|
||||
proto_suh = torch.randn((k,), dtype = torch.half, device = device)
|
||||
proto_svh = torch.randn((n,), dtype = torch.half, device = device)
|
||||
|
||||
# Create enough clones to cycle through to prevent L2 cache from skewing results
|
||||
proto_size = proto_a.numel() * 2 + proto_b.numel() * 2 + proto_c.numel() * 2
|
||||
num_buffers = max(assume_cache // proto_size + 1, 2)
|
||||
a = [proto_a.clone() for _ in range(num_buffers)]
|
||||
b = [proto_b.clone() for _ in range(num_buffers)]
|
||||
c = [proto_c.clone() for _ in range(num_buffers)]
|
||||
suh = [proto_suh.clone() for _ in range(num_buffers)]
|
||||
svh = [proto_svh.clone() for _ in range(num_buffers)]
|
||||
return a, b, c, suh, svh
|
||||
|
||||
|
||||
def get_abc_m(K, m, k, n, device, bszm_in, bszm_out):
|
||||
bszm = max(bszm_in, bszm_out)
|
||||
proto_a = torch.randn((bszm_in, m, k), dtype = torch.half, device = device)
|
||||
proto_b = [torch.zeros((k // 16, n // 16, 16 * K), dtype = torch.short, device = device) for _ in range(bszm)]
|
||||
proto_c = torch.zeros((bszm_out, m, n), dtype = torch.half, device = device)
|
||||
proto_suh = [torch.randn((k,), dtype = torch.half, device = device) for _ in range(bszm)]
|
||||
proto_svh = [torch.randn((n,), dtype = torch.half, device = device) for _ in range(bszm)]
|
||||
|
||||
# Create enough clones to cycle through to prevent L2 cache from skewing results
|
||||
proto_size = proto_a.numel() * 2 + sum(p.numel() for p in proto_b) * 2 + proto_c.numel() * 2
|
||||
num_buffers = max(assume_cache // proto_size + 1, 2)
|
||||
a = [proto_a.clone() for _ in range(num_buffers)]
|
||||
b = [[proto_b_.clone() for proto_b_ in proto_b] for _ in range(num_buffers)]
|
||||
c = [proto_c.clone() for _ in range(num_buffers)]
|
||||
suh = [[proto_suh_.clone() for proto_suh_ in proto_suh] for _ in range(num_buffers)]
|
||||
svh = [[proto_svh_.clone() for proto_svh_ in proto_svh] for _ in range(num_buffers)]
|
||||
trellis = b
|
||||
ptrs_suh = [torch.tensor([suh__.data_ptr() for suh__ in suh_], dtype = torch.long, device = device) for suh_ in suh]
|
||||
ptrs_svh = [torch.tensor([svh__.data_ptr() for svh__ in svh_], dtype = torch.long, device = device) for svh_ in svh]
|
||||
ptrs_trellis = [torch.tensor([trellis__.data_ptr() for trellis__ in trellis_], dtype = torch.long, device = device) for trellis_ in trellis]
|
||||
return a, b, c, suh, svh, ptrs_suh, ptrs_svh, ptrs_trellis
|
||||
|
||||
|
||||
def warmup(a, b, c, suh, svh, shape_idx):
|
||||
global g_spin
|
||||
num_buffers = len(a)
|
||||
for _ in range(num_warmup_passes):
|
||||
i = g_spin % num_buffers
|
||||
g_spin += 1
|
||||
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], shape_idx, 0, 0, 64)
|
||||
|
||||
|
||||
def warmup_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis):
|
||||
num_buffers = len(a)
|
||||
num_exp = ptrs_trellis[0].shape[0]
|
||||
m_indices = torch.arange(0, num_exp, dtype = torch.long, device = a[0].device).unsqueeze(0)
|
||||
K = b[0][0].shape[-1] // 16
|
||||
for i_ in range(num_warmup_passes):
|
||||
i = i_ % num_buffers
|
||||
ext.exl3_mgemm(
|
||||
a[i],
|
||||
ptrs_trellis[i],
|
||||
c[i],
|
||||
ptrs_suh[i],
|
||||
a[i],
|
||||
ptrs_svh[i],
|
||||
m_indices,
|
||||
None,
|
||||
K,
|
||||
-1,
|
||||
0,
|
||||
0,
|
||||
-1,
|
||||
-1,
|
||||
0
|
||||
)
|
||||
|
||||
|
||||
def benchmark(a, b, c, suh, svh, shape_idx, num_iter, num_sms):
|
||||
num_buffers = len(a)
|
||||
dummy = c[0][0, 0].item()
|
||||
with Timer() as t:
|
||||
for i_ in range(num_iter):
|
||||
i = i_ % num_buffers
|
||||
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], shape_idx, 0, 0, num_sms)
|
||||
dummy = c[i][0, 0].item()
|
||||
mean_time_ms = t.interval / num_iter * 1000
|
||||
return mean_time_ms
|
||||
|
||||
|
||||
def benchmark_per_launch(a, b, c, suh, svh, shape_idx, num_iter, num_sms, trim = outlier_trim, stream = None):
|
||||
global g_spin
|
||||
device = a[0].device
|
||||
if stream is None:
|
||||
stream = torch.cuda.current_stream(device)
|
||||
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
# Precreate events to reduce overhead jitter
|
||||
starts = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
|
||||
stops = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
|
||||
|
||||
# Timed loop
|
||||
for it in range(num_iter):
|
||||
i = g_spin % len(a)
|
||||
g_spin += 1
|
||||
starts[it].record(stream)
|
||||
ext.exl3_gemm(a[i], b[i], c[i], suh[i], a[i], svh[i], shape_idx, 0, 0, num_sms)
|
||||
stops[it].record(stream)
|
||||
|
||||
# Ensure all recorded events are complete before reading times
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
# Collect per-iteration latency in milliseconds (GPU time)
|
||||
per_ms = np.array([starts[it].elapsed_time(stops[it]) for it in range(num_iter)], dtype = np.float64)
|
||||
|
||||
# Robust stats
|
||||
median = float(np.median(per_ms))
|
||||
mean = float(per_ms.mean())
|
||||
std = float(per_ms.std(ddof=1)) if num_iter > 1 else 0.0
|
||||
|
||||
# Simple symmetric trimming
|
||||
if 0.0 < trim < 0.5 and num_iter > 4:
|
||||
lo = np.quantile(per_ms, trim)
|
||||
hi = np.quantile(per_ms, 1.0 - trim)
|
||||
trimmed = per_ms[(per_ms >= lo) & (per_ms <= hi)]
|
||||
trimmed_mean = float(trimmed.mean()) if trimmed.size else float('nan')
|
||||
else:
|
||||
trimmed = per_ms
|
||||
trimmed_mean = mean
|
||||
|
||||
return {
|
||||
"per_launch_ms": per_ms, # numpy array of length num_iter
|
||||
"mean_ms": mean,
|
||||
"median_ms": median,
|
||||
"std_ms": std,
|
||||
"trimmed_mean_ms": trimmed_mean,
|
||||
"trim_bounds_ms": (
|
||||
float(lo) if 'lo' in locals() else None,
|
||||
float(hi) if 'hi' in locals() else None
|
||||
),
|
||||
"kept_count": int(trimmed.size),
|
||||
"total_count": int(num_iter),
|
||||
}
|
||||
|
||||
|
||||
def benchmark_per_launch_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis, num_iter, num_sms, trim = outlier_trim, stream = None):
|
||||
device = a[0].device
|
||||
if stream is None:
|
||||
stream = torch.cuda.current_stream(device)
|
||||
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
num_exp = ptrs_trellis[0].shape[0]
|
||||
m_indices = torch.arange(0, num_exp, dtype = torch.long, device = a[0].device).unsqueeze(0)
|
||||
K = b[0][0].shape[-1] // 16
|
||||
|
||||
# Precreate events to reduce overhead jitter
|
||||
starts = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
|
||||
stops = [torch.cuda.Event(enable_timing = True) for _ in range(num_iter)]
|
||||
|
||||
# Timed loop
|
||||
for it in range(num_iter):
|
||||
i = it % len(a)
|
||||
starts[it].record(stream)
|
||||
for _ in range(2):
|
||||
ext.exl3_mgemm(
|
||||
a[i],
|
||||
ptrs_trellis[i],
|
||||
c[i],
|
||||
ptrs_suh[i],
|
||||
a[i],
|
||||
ptrs_svh[i],
|
||||
m_indices,
|
||||
None,
|
||||
K,
|
||||
shape_idx,
|
||||
0,
|
||||
0,
|
||||
-1,
|
||||
-1,
|
||||
num_sms
|
||||
)
|
||||
stops[it].record(stream)
|
||||
|
||||
# Ensure all recorded events are complete before reading times
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
# Collect per-iteration latency in milliseconds (GPU time)
|
||||
per_ms = np.array([starts[it].elapsed_time(stops[it]) / 2 for it in range(num_iter)], dtype = np.float64)
|
||||
|
||||
# Robust stats
|
||||
median = float(np.median(per_ms))
|
||||
mean = float(per_ms.mean())
|
||||
std = float(per_ms.std(ddof=1)) if num_iter > 1 else 0.0
|
||||
|
||||
# Simple symmetric trimming
|
||||
if 0.0 < trim < 0.5 and num_iter > 4:
|
||||
lo = np.quantile(per_ms, trim)
|
||||
hi = np.quantile(per_ms, 1.0 - trim)
|
||||
trimmed = per_ms[(per_ms >= lo) & (per_ms <= hi)]
|
||||
trimmed_mean = float(trimmed.mean()) if trimmed.size else float('nan')
|
||||
else:
|
||||
trimmed = per_ms
|
||||
trimmed_mean = mean
|
||||
|
||||
return {
|
||||
"per_launch_ms": per_ms, # numpy array of length num_iter
|
||||
"mean_ms": mean,
|
||||
"median_ms": median,
|
||||
"std_ms": std,
|
||||
"trimmed_mean_ms": trimmed_mean,
|
||||
"trim_bounds_ms": (
|
||||
float(lo) if 'lo' in locals() else None,
|
||||
float(hi) if 'hi' in locals() else None
|
||||
),
|
||||
"kept_count": int(trimmed.size),
|
||||
"total_count": int(num_iter),
|
||||
}
|
||||
|
||||
def get_floor(hist_sms):
|
||||
best = float("inf"), 0, 0
|
||||
for l, s, i in hist_sms:
|
||||
if l < best[0]:
|
||||
best = l, s, i
|
||||
for l, s, i in hist_sms:
|
||||
if l < best[0] * 1.008:
|
||||
return l, s, i
|
||||
return best
|
||||
|
||||
|
||||
def test_latencies(mod, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices):
|
||||
|
||||
g_hist = []
|
||||
|
||||
for shape_idx in shape_indices:
|
||||
|
||||
max_sms = ext.g_get_num_sms(device_idx)
|
||||
|
||||
warmup(a, b, c, suh, svh, shape_idx)
|
||||
|
||||
hist_sms = []
|
||||
for num_sms in range(0, max_sms + 1, 8):
|
||||
# mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_a, max(num_sms, 1))
|
||||
num_sms_ = max(num_sms, 1)
|
||||
stats = benchmark_per_launch(a, b, c, suh, svh, shape_idx, num_benchmark_iter_a, num_sms_)
|
||||
mean = stats["trimmed_mean_ms"]
|
||||
hist_sms.append((mean, num_sms_, shape_idx))
|
||||
|
||||
_, best_sms, _ = get_floor(hist_sms)
|
||||
|
||||
sms_a = best_sms - 8
|
||||
sms_b = best_sms + 8
|
||||
if sms_a < 0:
|
||||
sms_a = 0
|
||||
sms_b = 16
|
||||
if sms_b > num_sms + 4:
|
||||
sms_b = num_sms + 4
|
||||
sms_a = sms_b - 16
|
||||
|
||||
warmup(a, b, c, suh, svh, shape_idx)
|
||||
|
||||
for num_sms in range(sms_a, sms_b + 1, 2):
|
||||
# mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_b, min(max(num_sms, 1), max_sms))
|
||||
num_sms_ = min(max(num_sms, 1), max_sms)
|
||||
stats = benchmark_per_launch(a, b, c, suh, svh, shape_idx, num_benchmark_iter_b, num_sms_)
|
||||
mean = stats["trimmed_mean_ms"]
|
||||
g_hist.append((mean, num_sms_, shape_idx))
|
||||
|
||||
g_hist = sorted(g_hist, key = lambda t: t[1])
|
||||
best_g_lat, best_g_sms, best_g_idx = get_floor(g_hist)
|
||||
|
||||
print(f"mod {mod}, dev {device_idx}, m {m:6}, k {k:6}, n {n:6}, K {K}, mean {best_g_lat:8.5f} ms, num_sms {best_g_sms:3}, shape_idx {best_g_idx}")
|
||||
return best_g_lat, best_g_sms, best_g_idx
|
||||
|
||||
|
||||
def test_latencies_m(mod, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices, bszm_in, bszm_out):
|
||||
|
||||
best_g_lat = float("inf")
|
||||
best_g_sms = None
|
||||
best_g_idx = None
|
||||
|
||||
best_l_lat = float("inf")
|
||||
best_l_sms = None
|
||||
best_l_idx = None
|
||||
|
||||
for shape_idx in shape_indices:
|
||||
|
||||
# Skip incompatible shapes
|
||||
# if not ext.exl3_gemm_shape_compat(shape_idx, m, k, n, K):
|
||||
# continue
|
||||
|
||||
max_sms = ext.g_get_num_sms(device_idx)
|
||||
|
||||
warmup_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis)
|
||||
|
||||
best_sms_lat = float("inf")
|
||||
for num_sms in range(0, max_sms + 1, 8):
|
||||
# mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_a, max(num_sms, 1))
|
||||
stats = benchmark_per_launch_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis, num_benchmark_iter_a, max(num_sms, 1))
|
||||
mean = stats["trimmed_mean_ms"]
|
||||
if mean < best_sms_lat:
|
||||
best_sms_lat = mean
|
||||
best_sms = max(num_sms, 1)
|
||||
|
||||
sms_a = best_sms - 8
|
||||
sms_b = best_sms + 8
|
||||
if sms_a < 0:
|
||||
sms_a = 0
|
||||
sms_b = 16
|
||||
if sms_b > num_sms + 4:
|
||||
sms_b = num_sms + 4
|
||||
sms_a = sms_b - 16
|
||||
|
||||
warmup_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis)
|
||||
|
||||
for num_sms in range(sms_a, sms_b + 1, 2):
|
||||
# mean = benchmark(a, b, c, suh, svh, shape_idx, num_benchmark_iter_b, min(max(num_sms, 1), max_sms))
|
||||
stats = benchmark_per_launch_m(a, b, c, suh, svh, shape_idx, ptrs_suh, ptrs_svh, ptrs_trellis, num_benchmark_iter_b, min(max(num_sms, 1), max_sms))
|
||||
mean = stats["trimmed_mean_ms"]
|
||||
if mean < best_g_lat:
|
||||
best_g_lat = mean
|
||||
best_g_sms = min(max(num_sms, 1), max_sms)
|
||||
best_g_idx = shape_idx
|
||||
if mean < best_l_lat:
|
||||
best_l_lat = mean
|
||||
best_l_sms = min(max(num_sms, 1), max_sms)
|
||||
best_l_idx = shape_idx
|
||||
|
||||
print(f"mod {mod}, dev {device_idx}, m {m:6}, k {k:6}, n {n:6}, K {K}, mean {best_l_lat:8.5f} ms, num_sms {best_l_sms:3}, shape_idx {best_l_idx}, i/o {bszm_in}/{bszm_out}")
|
||||
|
||||
print()
|
||||
print(f" --> mod {mod}, dev {device_idx}, m {m:6}, k {k:6}, n {n:6}, K {K}, mean {best_g_lat:8.5f} ms, num_sms {best_g_sms:3}, shape_idx {best_g_idx}, i/o {bszm_in}/{bszm_out}")
|
||||
print()
|
||||
return best_g_lat, best_g_sms, best_g_idx
|
||||
|
||||
|
||||
def tune_shape(K, m, k, n, device):
|
||||
|
||||
a, b, c, suh, svh = get_abc(K, m, k, n, device)
|
||||
device_idx = torch.device(device).index
|
||||
cc = ext.g_get_cc(device_idx)
|
||||
|
||||
res = {
|
||||
"K": K,
|
||||
"m": m,
|
||||
"k": k,
|
||||
"n": n,
|
||||
"cc": cc,
|
||||
"bszm_in": 1,
|
||||
"bszm_out": 1,
|
||||
}
|
||||
res_128 = None
|
||||
res_256 = None
|
||||
res_512 = None
|
||||
|
||||
if True:
|
||||
lat_128, sms_128, idx_128 = test_latencies(128, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices_128)
|
||||
res_128 = res.copy()
|
||||
res_128.update({"lat": lat_128, "sms": sms_128, "idx": idx_128})
|
||||
|
||||
if n % 256 == 0:
|
||||
lat_256, sms_256, idx_256 = test_latencies(256, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices_256)
|
||||
if lat_128 < lat_256:
|
||||
lat_256, sms_256, idx_256 = lat_128, sms_128, idx_128
|
||||
res_256 = res.copy()
|
||||
res_256.update({"lat": lat_256, "sms": sms_256, "idx": idx_256})
|
||||
|
||||
if n % 512 == 0:
|
||||
lat_512, sms_512, idx_512 = test_latencies(512, K, m, k, n, a, b, c, suh, svh, device_idx, shape_indices_512)
|
||||
if lat_128 < lat_512:
|
||||
lat_512, sms_512, idx_512 = lat_128, sms_128, idx_128
|
||||
if lat_256 < lat_512:
|
||||
lat_512, sms_512, idx_512 = lat_256, sms_256, idx_256
|
||||
res_512 = res.copy()
|
||||
res_512.update({"lat": lat_512, "sms": sms_512, "idx": idx_512})
|
||||
|
||||
return res_128, res_256, res_512
|
||||
|
||||
|
||||
def tune_shape_m(K, m, k, n, device, bszm_in, bszm_out):
|
||||
|
||||
a, b, c, suh, svh, ptrs_suh, ptrs_svh, ptrs_trellis = get_abc_m(K, m, k, n, device, bszm_in, bszm_out)
|
||||
device_idx = torch.device(device).index
|
||||
cc = ext.g_get_cc(device_idx)
|
||||
|
||||
res = {
|
||||
"K": K,
|
||||
"m": m,
|
||||
"k": k,
|
||||
"n": n,
|
||||
"cc": cc,
|
||||
"bszm_in": 1,
|
||||
"bszm_out": 1,
|
||||
}
|
||||
res_128 = None
|
||||
res_256 = None
|
||||
res_512 = None
|
||||
|
||||
if True:
|
||||
lat_128, sms_128, idx_128 = test_latencies_m(128, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices_128, bszm_in, bszm_out)
|
||||
res_128 = res.copy()
|
||||
res_128.update({"lat": lat_128, "sms": sms_128, "idx": idx_128})
|
||||
|
||||
if n % 256 == 0:
|
||||
lat_256, sms_256, idx_256 = test_latencies_m(256, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices_256, bszm_in, bszm_out)
|
||||
if lat_128 < lat_256:
|
||||
lat_256, sms_256, idx_256 = lat_128, sms_128, idx_128
|
||||
res_256 = res.copy()
|
||||
res_256.update({"lat": lat_256, "sms": sms_256, "idx": idx_256})
|
||||
|
||||
if n % 512 == 0:
|
||||
lat_512, sms_512, idx_512 = test_latencies_m(512, K, m, k, n, a, b, c, suh, svh, device_idx, ptrs_suh, ptrs_svh, ptrs_trellis, shape_indices_512, bszm_in, bszm_out)
|
||||
if lat_128 < lat_512:
|
||||
lat_512, sms_512, idx_512 = lat_128, sms_128, idx_128
|
||||
if lat_256 < lat_512:
|
||||
lat_512, sms_512, idx_512 = lat_256, sms_256, idx_256
|
||||
res_512 = res.copy()
|
||||
res_512.update({"lat": lat_512, "sms": sms_512, "idx": idx_512})
|
||||
|
||||
return res_128, res_256, res_512
|
||||
|
||||
|
||||
def tune_gemm():
|
||||
out_128 = "struct TSample samples_128[] =\n{\n"
|
||||
out_256 = "struct TSample samples_256[] =\n{\n"
|
||||
out_512 = "struct TSample samples_512[] =\n{\n"
|
||||
for device in devices:
|
||||
for m in shapes_m:
|
||||
for k in shapes_k:
|
||||
for n in shapes_n:
|
||||
for ki, K in enumerate(Ks):
|
||||
res_128, res_256, res_512 = tune_shape(K, m, k, n, device)
|
||||
r = res_128
|
||||
if r:
|
||||
out_128 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']} }},\n"
|
||||
r = res_256
|
||||
if r:
|
||||
out_256 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']} }},\n"
|
||||
r = res_512
|
||||
if r:
|
||||
out_512 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']} }},\n"
|
||||
out_128 = out_128 + " { 0, 0, 0, 0, 0, 0, 0 }\n};"
|
||||
out_256 = out_256 + " { 0, 0, 0, 0, 0, 0, 0 }\n};"
|
||||
out_512 = out_512 + " { 0, 0, 0, 0, 0, 0, 0 }\n};"
|
||||
print(out_128)
|
||||
print()
|
||||
print(out_256)
|
||||
print()
|
||||
print(out_512)
|
||||
print()
|
||||
|
||||
|
||||
def tune_mgemm():
|
||||
out_128 = "struct TMSample msamples_128[] =\n{\n"
|
||||
out_256 = "struct TMSample msamples_256[] =\n{\n"
|
||||
out_512 = "struct TMSample msamples_512[] =\n{\n"
|
||||
for device in devices:
|
||||
for m in shapes_m:
|
||||
for k in shapes_k:
|
||||
for n in shapes_n:
|
||||
for ki, K in enumerate(Ks):
|
||||
for (bszm_in, bszm_out) in mgemm_bszm_io:
|
||||
res_128, res_256, res_512 = tune_shape_m(K, m, k, n, device, bszm_in, bszm_out)
|
||||
r = res_128
|
||||
if r:
|
||||
out_128 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']}, {r['bszm_in']}, {r['bszm_out']} }},\n"
|
||||
r = res_256
|
||||
if r:
|
||||
out_256 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']}, {r['bszm_in']}, {r['bszm_out']} }},\n"
|
||||
r = res_512
|
||||
if r:
|
||||
out_512 += f" {{ {r['cc']}, {r['K']}, {r['m']}, {r['k']}, {r['n']}, {r['idx']}, {r['sms']}, {r['bszm_in']}, {r['bszm_out']} }},\n"
|
||||
out_128 = out_128 + " { 0, 0, 0, 0, 0, 0, 0, 0, 0 }\n};"
|
||||
out_256 = out_256 + " { 0, 0, 0, 0, 0, 0, 0, 0, 0 }\n};"
|
||||
out_512 = out_512 + " { 0, 0, 0, 0, 0, 0, 0, 0, 0 }\n};"
|
||||
print(out_128)
|
||||
print()
|
||||
print(out_256)
|
||||
print()
|
||||
print(out_512)
|
||||
print()
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main():
|
||||
tune_gemm()
|
||||
# tune_mgemm()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user