multi instance generation for CkTileEngine (#2080)

* Add support for multi-instance verification, print detail for each instance, documentation fix

* clang formatted

* Added Readme file

* updated readme

* Addressing review comments

* clang formatted

* Updated ReadMe and GPU reference code

* simplified dispatch kernel code

* indentation

[ROCm/composable_kernel commit: 7cadf187e2]
This commit is contained in:
Khushbu Agarwal
2025-04-21 08:39:45 -07:00
committed by GitHub
parent dd2c3289c9
commit 74210a9dfc
5 changed files with 202 additions and 140 deletions

View File

@@ -0,0 +1,51 @@
# GEMM Matrix Multiplication
Use the files in this folder to generate and build applications that run matrix multiplications using ck_tile programming, based on the kernel parameters specified in the config file `./configs/instance_combination.json`.
# Kernel Configurations
The user needs to provide a kernel configuration — datatype, layout, tile size, warp size, padding, pipeline, scheduler and epilogue — in the config file. For reference, please see `./configs/instance_combination.json`.
## Build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# To generate the executable
make tile_engine_gemm -j
```
`tile_engine_gemm` will be located in the `./bin/` directory.
## tile_engine_gemm inputs
```
-m m dimension (default:3840)
-n n dimension (default:4096)
-k k dimension (default:2048)
-stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-split_k SplitK value (default:1)
-v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2)
-warmup Number of iterations before benchmarking the kernel (default:50)
-repeat Number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-init Value for initializing tensors - random: 0, linear: 1, constant(1): 2 (default:0)
-pipeline possible values are: compv3, compv4, mem (default:compv3)
-scheduler possible values are: intrawave, interwave (default:intrawave)
-epilogue possible values are: cshuffle, default (default:cshuffle)
-pad_m Pad in m direction - true/false (default:false)
-pad_n Pad in n direction - true/false (default:false)
-pad_k Pad in k direction - true/false (default:false)
Note: pipeline, scheduler, epilogue, pad_m, pad_n and pad_k must each be one of the options specified in instance_combination.json
```
## Example
The example below runs the GEMM kernel with the default matrix dimensions, using the compv3 pipeline, the intrawave scheduler and the default epilogue, for all compatible tile sizes listed in the config file.
```
./bin/tile_engine_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default
```

View File

@@ -19,7 +19,7 @@
"values": [256]
},
"tile_k": {
"values": [64]
"values": [64, 32]
},
"warp_m": {
"values": [2]

View File

@@ -6,11 +6,16 @@
#include "gemm_dispatcher.hpp"
#include "gemm_host_api.hpp"
float gemm_kernel_launch(KernelTraits& trait,
ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s)
void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify,
KernelTraits& trait,
ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s)
{
return GemmDispatcher::dispatch(trait, args, s);
return GemmDispatcher::dispatch(
c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, trait, args, s);
}
template <typename ADataType,
@@ -20,11 +25,10 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
bool run(const ck_tile::ArgParser& arg_parser)
void run(const ck_tile::ArgParser& arg_parser)
{
const ALayout a_layout = ALayout{};
const BLayout b_layout = BLayout{};
// const CLayout c_layout = CLayout{};
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
ck_tile::index_t M = arg_parser.get_int("m");
@@ -113,43 +117,47 @@ bool run(const ck_tile::ArgParser& arg_parser)
trait.kPadN = arg_parser.get_bool("pad_n");
trait.kPadK = arg_parser.get_bool("pad_k");
float ave_time = gemm_kernel_launch(
trait, gemm_args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
<< " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
<< " B Type = " << DataTypeTraits<BDataType>::name
<< " C Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
<< " C Type = " << DataTypeTraits<CDataType>::name << std::endl;
ck_tile::HostTensor<CDataType> c_m_n_host_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
if(verify)
{
pass = gemm_verify<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
verify,
a_m_k,
b_k_n,
c_m_n_dev_result,
a_m_k_dev_buf,
b_k_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch);
gemm_host_reference<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(verify,
a_m_k,
b_k_n,
c_m_n_host_result,
a_m_k_dev_buf,
b_k_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C);
}
return pass;
gemm_kernel_launch(c_m_n_dev_buf,
c_m_n_host_result,
c_m_n_dev_result,
verify,
trait,
gemm_args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
return;
}
int main(int argc, char* argv[])
@@ -159,7 +167,8 @@ int main(int argc, char* argv[])
auto [result, parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
return run<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(parser);
run<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(parser);
return 0;
}
catch(const std::exception& e)
{

View File

@@ -1,3 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
@@ -54,24 +57,21 @@ struct DataTypeTraits<ck_tile::pk_int4_t>
static constexpr const char* name = "pk_int4_t";
};
/**
* @brief trait for GEMM kernel
* @param pipeline: pipeline name
* @param scheduler: scheduler name
* @param epilogue: epilogue name
* @param kPadM: padding for M dimension
* @param kPadN: padding for N dimension
* @param kPadK: padding for K dimension
*
*/
/// @brief Defines the configuration parameters for a GEMM operation, enabling the selection of a
/// specific kernel instance based on the provided settings.
struct KernelTraits
{
/// @brief The name of the pipeline.
std::string pipeline;
/// @brief The name of the scheduler (e.g., "intrawave", "interwave").
std::string scheduler;
/// @brief The name of the epilogue (e.g., "cshuffle", "default").
std::string epilogue;
/// @brief Indicates whether padding is applied to the M dimension.
bool kPadM;
/// @brief Indicates whether padding is applied to the N dimension.
bool kPadN;
/// @brief Indicates whether padding is applied to the K dimension.
bool kPadK;
};
@@ -184,11 +184,28 @@ void permute_vectors_i4x4_b(Tensor& tensor)
}
}
/**
* @brief Function to verify the kernel output with reference implementation on CPU/GPU
*
*/
/// @brief Compares the device-computed GEMM result against the reference result
///        (produced on the host or by the GPU reference kernel) and prints the
///        error thresholds and a pass/fail verdict to stdout.
/// @param K                  K dimension of the GEMM (accumulation length).
/// @param kbatch             split-K factor; fed into the tolerance calculation.
/// @param c_m_n_dev_result   output tensor copied back from the kernel under test.
/// @param c_m_n_host_result  reference output tensor to compare against.
void compare(ck_tile::index_t K,
             ck_tile::index_t kbatch,
             ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
             ck_tile::HostTensor<CDataType>& c_m_n_host_result)
{
    // The largest reference value is used to scale the tolerances; presumably
    // calculate_rtol_atol also accounts for the accumulation depth (K, kbatch)
    // — confirm against its definition.
    const float max_accumulated_value =
        *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
    const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
        K, kbatch, max_accumulated_value);
    // rtol_atol is a two-element compile-time tuple: [0] = relative, [1] = absolute.
    bool pass = ck_tile::check_err(c_m_n_dev_result,
                                   c_m_n_host_result,
                                   "Error: Incorrect results!",
                                   rtol_atol.at(ck_tile::number<0>{}),
                                   rtol_atol.at(ck_tile::number<1>{}));
    std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
              << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
    std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
template <typename ADataType,
typename BDataType,
typename AccDataType,
@@ -196,43 +213,25 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
bool gemm_verify(int verify,
ck_tile::HostTensor<ADataType>& a_m_k,
ck_tile::HostTensor<BDataType>& b_k_n,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch)
void gemm_host_reference(int verify,
ck_tile::HostTensor<ADataType>& a_m_k,
ck_tile::HostTensor<BDataType>& b_k_n,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C)
{
bool pass = true;
if(verify == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
c_m_n_host_result.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
a_m_k, b_k_n, c_m_n_host_result);
}
else if(verify == 2)
{
@@ -241,29 +240,14 @@ bool gemm_verify(int verify,
// Restore input for B for gpu reference
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
c_m_n_gpu_ref.SetZero();
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes());
c_m_n_host_result.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes()));
ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes()));
ck_tile::hip_check_error(
hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes()));
ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
a_m_k.get_element_space_size_in_bytes(),
hipMemcpyHostToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
b_k_n.get_element_space_size_in_bytes(),
hipMemcpyHostToDevice));
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
@@ -273,30 +257,6 @@ bool gemm_verify(int verify,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
c_m_n_dev_result.get_element_space_size_in_bytes(),
hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipFree(d_A));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_gpu_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data());
}
return pass;
}

View File

@@ -447,6 +447,17 @@ struct GemmKernel {{
return ave_time;
}}
/// @brief Returns a human-readable description of this kernel instance's
///        compile-time configuration (block tile, wave map, warp tile,
///        padding flags, pipeline, epilogue and scheduler).
/// Fixes: "Bllktile" -> "BlockTile", "PadidngM" -> "PaddingM", and the
/// opening '<' in "GemmKernel<" is now balanced by a closing '>'.
static std::string get_name() {{
    return std::string("GemmKernel<BlockTile: ") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) + ", " +
           "WaveMap: " + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + ", " +
           "WarpTile: " + std::to_string(WarpTileM) + "x" + std::to_string(WarpTileN) + "x" + std::to_string(WarpTileK) + ", " +
           "PaddingM: " + "{kPadM}" + ", " +
           "PaddingN: " + "{kPadN}" + ", " +
           "PaddingK: " + "{kPadK}" + ", " +
           "Pipeline: " + "{pipeline}" + ", " +
           "Epilogue: " + "{epilogue}" + ", " +
           "Scheduler: " + "{scheduler}" + ">";
}}
}};
"""
@@ -476,7 +487,10 @@ struct GemmDispatcher {
static auto& get_kernel_map() {
// Use a static local variable
static std::unordered_map<std::string,
std::function<float(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map;
std::function<void(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map;
return kernel_map;
}
@@ -499,9 +513,12 @@ struct GemmDispatcher {
for group in self.all_kernels:
content += f""" kernel_map["{group}"] = [](ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s) {{
std::vector<float> results;"""
content += f""" kernel_map["{group}"] = [](ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s) {{
"""
for tile in tile_params:
# Check if we have valid tile/warp combinations
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
@@ -509,21 +526,46 @@ struct GemmDispatcher {
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
continue
content += f"""
//we can have multiple tiles config for the one kernel_trait
return {group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>::launch(args, s);"""
content += """
};\n"""
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);"""
content += f"""
}};\n"""
content += """ }
static float dispatch(const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
template <typename Kernel>
static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
float avg_time = Kernel::launch(args, s);
std::string description = Kernel::get_name();
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
std::size_t num_byte = sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.N * args.K + sizeof(CDataType) * args.M * args.N;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_byte / 1.E6 / avg_time;
std::cout << "Performance for " << description << " : " << avg_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
if(verify)
compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result);
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
}
static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
const ck_tile::stream_config& s) {
init();
const std::string key = assemble_key(trait);
auto& kernel_map = get_kernel_map();
if(auto it = kernel_map.find(key); it != kernel_map.end()) {
return it->second(gemm_args, s); //Running single instance
return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify,gemm_args, s);
}
throw std::runtime_error("No suitable kernel found: " + key);
}