multi instance generation for CkTileEngine (#2080)

* Add support for multi-instance verification, print detail for each instance, documentation fix

* clang formatted

* Added Readme file

* updated readme

* Addressing review comments

* clang formatted

* Updated ReadMe and GPU reference code

* simplified dispatch kernel code

* indentation
This commit is contained in:
Khushbu Agarwal
2025-04-21 08:39:45 -07:00
committed by GitHub
parent c318ec0778
commit 7cadf187e2
5 changed files with 202 additions and 140 deletions

View File

@@ -447,6 +447,17 @@ struct GemmKernel {{
return ave_time;
}}
// Returns a human-readable description of this kernel instance, used to label
// per-instance output. The tile, warp-layout and warp-tile dimensions are
// formatted at run time with std::to_string; the padding, pipeline, epilogue
// and scheduler fields are filled in by the generator template when this
// source text is emitted.
static std::string get_name() {{
    return std::string("GemmKernel<BlockTile: ") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) + ", " +
           "WaveMap: " + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + ", " +
           "WarpTile: " + std::to_string(WarpTileM) + "x" + std::to_string(WarpTileN) + "x" + std::to_string(WarpTileK) + ", " +
           "PaddingM: " + "{kPadM}" + ", " +
           "PaddingN: " + "{kPadN}" + ", " +
           "PaddingK: " + "{kPadK}" + ", " +
           "Pipeline: " + "{pipeline}" + ", " +
           "Epilogue: " + "{epilogue}" + ", " +
           "Scheduler: " + "{scheduler}";
}}
}};
"""
@@ -476,7 +487,10 @@ struct GemmDispatcher {
static auto& get_kernel_map() {
// Use a static local variable
static std::unordered_map<std::string,
std::function<float(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map;
std::function<void(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map;
return kernel_map;
}
@@ -499,9 +513,12 @@ struct GemmDispatcher {
for group in self.all_kernels:
content += f""" kernel_map["{group}"] = [](ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s) {{
std::vector<float> results;"""
content += f""" kernel_map["{group}"] = [](ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s) {{
"""
for tile in tile_params:
# Check if we have valid tile/warp combinations
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
@@ -509,21 +526,46 @@ struct GemmDispatcher {
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
continue
content += f"""
//we can have multiple tiles config for the one kernel_trait
return {group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>::launch(args, s);"""
content += """
};\n"""
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);"""
content += f"""
}};\n"""
content += """ }
static float dispatch(const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
// Runs one kernel instance end to end: launches it, copies the device output
// into c_m_n_dev_result, prints a one-line performance summary, optionally
// compares against c_m_n_host_result, and finally zeroes both the device
// buffer and the device-result tensor.
//
// Kernel must provide static launch(args, s) and static get_name().
// verify: non-zero requests the comparison against the host reference.
template <typename Kernel>
static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf,
                       ck_tile::HostTensor<CDataType>& c_m_n_host_result,
                       ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
                       int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
    // Launch first, then fetch the output so it reflects this instance.
    const float elapsed = Kernel::launch(args, s);
    const std::string kernel_label = Kernel::get_name();
    c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

    // Work and traffic estimates: 2*M*N*K flops, and one pass over A, B and C.
    const std::size_t total_flop = std::size_t(2) * args.M * args.N * args.K;
    const std::size_t total_bytes = sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.N * args.K + sizeof(CDataType) * args.M * args.N;

    // The scale factors assume the launch time is in milliseconds, matching
    // the "ms" label printed below.
    const float tera_flops = static_cast<float>(total_flop) / 1.E9 / elapsed;
    const float gigabytes_per_sec = total_bytes / 1.E6 / elapsed;

    std::cout << "Performance for " << kernel_label << " : " << elapsed << " ms, "
              << tera_flops << " TFlops, " << gigabytes_per_sec << " GB/s, " << std::endl;

    if(verify)
    {
        compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result);
    }

    // Clear both outputs before returning.
    c_m_n_dev_buf.SetZero();
    c_m_n_dev_result.SetZero();
}
static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
const ck_tile::stream_config& s) {
init();
const std::string key = assemble_key(trait);
auto& kernel_map = get_kernel_map();
if(auto it = kernel_map.find(key); it != kernel_map.end()) {
return it->second(gemm_args, s); //Running single instance
return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify,gemm_args, s);
}
throw std::runtime_error("No suitable kernel found: " + key);
}