Multi-instance generation for CkTileEngine (#2080)

* Add support for multi-instance verification, print details for each instance, and fix documentation

* clang formatted

* Added README file

* Updated README

* Addressing review comments

* clang formatted

* Updated README and GPU reference code

* simplified dispatch kernel code

* indentation

[ROCm/composable_kernel commit: 7cadf187e2]
This commit is contained in:
Khushbu Agarwal
2025-04-21 08:39:45 -07:00
committed by GitHub
parent dd2c3289c9
commit 74210a9dfc
5 changed files with 202 additions and 140 deletions

View File

@@ -6,11 +6,16 @@
#include "gemm_dispatcher.hpp"
#include "gemm_host_api.hpp"
// NOTE: this span is a diff rendering without +/- markers, so two versions of
// gemm_kernel_launch appear interleaved. The enclosing hunk header reports the
// region growing from 11 to 16 lines, which matches the three-parameter,
// float-returning form being replaced by the seven-parameter, void form.
//
// Version 1: takes the kernel traits, the GEMM host arguments and the stream
// configuration, and returns whatever GemmDispatcher::dispatch(trait, args, s)
// returns as a float.
float gemm_kernel_launch(KernelTraits& trait,
ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s)
// Version 2: additionally receives the device buffer for C, the host-side
// reference tensor, the host-side tensor for the device result, and the
// `verify` flag; it returns nothing and passes every argument straight through
// to GemmDispatcher::dispatch.
void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify,
KernelTraits& trait,
ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s)
{
// Version 1 body: forward the three arguments and return the dispatcher's result.
return GemmDispatcher::dispatch(trait, args, s);
// Version 2 body: forward all seven arguments in declaration order.
// NOTE(review): `return <expr>;` inside a void function is only well-formed if
// the seven-argument dispatch overload itself returns void — confirm against
// gemm_dispatcher.hpp.
return GemmDispatcher::dispatch(
c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, trait, args, s);
}
template <typename ADataType,
@@ -20,11 +25,10 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
bool run(const ck_tile::ArgParser& arg_parser)
void run(const ck_tile::ArgParser& arg_parser)
{
const ALayout a_layout = ALayout{};
const BLayout b_layout = BLayout{};
// const CLayout c_layout = CLayout{};
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
ck_tile::index_t M = arg_parser.get_int("m");
@@ -113,43 +117,47 @@ bool run(const ck_tile::ArgParser& arg_parser)
trait.kPadN = arg_parser.get_bool("pad_n");
trait.kPadK = arg_parser.get_bool("pad_k");
float ave_time = gemm_kernel_launch(
trait, gemm_args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
<< " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
<< " B Type = " << DataTypeTraits<BDataType>::name
<< " C Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
<< " C Type = " << DataTypeTraits<CDataType>::name << std::endl;
ck_tile::HostTensor<CDataType> c_m_n_host_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
if(verify)
{
pass = gemm_verify<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
verify,
a_m_k,
b_k_n,
c_m_n_dev_result,
a_m_k_dev_buf,
b_k_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch);
gemm_host_reference<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(verify,
a_m_k,
b_k_n,
c_m_n_host_result,
a_m_k_dev_buf,
b_k_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C);
}
return pass;
gemm_kernel_launch(c_m_n_dev_buf,
c_m_n_host_result,
c_m_n_dev_result,
verify,
trait,
gemm_args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
return;
}
int main(int argc, char* argv[])
@@ -159,7 +167,8 @@ int main(int argc, char* argv[])
auto [result, parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
return run<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(parser);
run<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(parser);
return 0;
}
catch(const std::exception& e)
{