mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
multi instance generation for CkTileEngine (#2080)
* Add support for multi-instance verification, print detail for each instance, documentation fix * clang formatted * Added Readme file * updated readme * Addressing review comments * clang formatted * Updated ReadMe and GPU reference code * simplified dispatch kernel code * indentation
This commit is contained in:
@@ -447,6 +447,17 @@ struct GemmKernel {{
|
||||
|
||||
return ave_time;
|
||||
}}
|
||||
static std::string get_name() {{
|
||||
return std::string("GemmKernel<Bllktile: ") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) + ", " +
|
||||
"WaveMap: " + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + ", " +
|
||||
"WarpTile: " + std::to_string(WarpTileM) + "x" + std::to_string(WarpTileN) + "x" + std::to_string(WarpTileK) + ", " +
|
||||
"PadidngM: " + "{kPadM}" + ", " +
|
||||
"PaddingN: " + "{kPadN}" + ", " +
|
||||
"PaddingK: " + "{kPadK}" + ", " +
|
||||
"Pipeline: " + "{pipeline}" + ", " +
|
||||
"Epilogue: " + "{epilogue}" + ", " +
|
||||
"Scheduler: " + "{scheduler}";
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
@@ -476,7 +487,10 @@ struct GemmDispatcher {
|
||||
static auto& get_kernel_map() {
|
||||
// Use a static local variable
|
||||
static std::unordered_map<std::string,
|
||||
std::function<float(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map;
|
||||
std::function<void(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map;
|
||||
return kernel_map;
|
||||
}
|
||||
|
||||
@@ -499,9 +513,12 @@ struct GemmDispatcher {
|
||||
|
||||
|
||||
for group in self.all_kernels:
|
||||
content += f""" kernel_map["{group}"] = [](ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s) {{
|
||||
std::vector<float> results;"""
|
||||
content += f""" kernel_map["{group}"] = [](ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s) {{
|
||||
"""
|
||||
for tile in tile_params:
|
||||
# Check if we have valid tile/warp combinations
|
||||
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
|
||||
@@ -509,21 +526,46 @@ struct GemmDispatcher {
|
||||
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
|
||||
continue
|
||||
content += f"""
|
||||
//we can have multiple tiles config for the one kernel_trait
|
||||
return {group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>::launch(args, s);"""
|
||||
content += """
|
||||
};\n"""
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);"""
|
||||
content += f"""
|
||||
}};\n"""
|
||||
|
||||
content += """ }
|
||||
|
||||
|
||||
static float dispatch(const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
|
||||
template <typename Kernel>
|
||||
static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
float avg_time = Kernel::launch(args, s);
|
||||
std::string description = Kernel::get_name();
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
|
||||
std::size_t num_byte = sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.N * args.K + sizeof(CDataType) * args.M * args.N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Performance for " << description << " : " << avg_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
if(verify)
|
||||
compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result);
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
}
|
||||
|
||||
static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
|
||||
const ck_tile::stream_config& s) {
|
||||
init();
|
||||
const std::string key = assemble_key(trait);
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(auto it = kernel_map.find(key); it != kernel_map.end()) {
|
||||
return it->second(gemm_args, s); //Running single instance
|
||||
return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify,gemm_args, s);
|
||||
}
|
||||
throw std::runtime_error("No suitable kernel found: " + key);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user