mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
fix the buffer intrinsic names for clang >=20 (#2228)
This commit is contained in:
@@ -4,14 +4,22 @@
|
||||
#include <type_traits>
|
||||
|
||||
template <typename T>
|
||||
constexpr const char* DataTypeToString() {
|
||||
if constexpr (std::is_same_v<T, ck_tile::half_t>) {
|
||||
constexpr const char* DataTypeToString()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck_tile::half_t>)
|
||||
{
|
||||
return "fp16";
|
||||
} else if constexpr (std::is_same_v<T, ck_tile::fp8_t>) {
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::fp8_t>)
|
||||
{
|
||||
return "fp8";
|
||||
} else if constexpr (std::is_same_v<T, ck_tile::bf8_t>) {
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::bf8_t>)
|
||||
{
|
||||
return "bf8";
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
@@ -112,8 +120,9 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
float ave_time = flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
float ave_time =
|
||||
flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
@@ -121,18 +130,15 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>() << " M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
|
||||
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
|
||||
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
|
||||
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename PrecType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
|
||||
int run_flatmm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
@@ -147,7 +153,7 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
@@ -182,7 +188,7 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
c_rslt_host.SetZero();
|
||||
|
||||
// do pre-shuffle
|
||||
std::string mfma = arg_parser.get_str("prec");
|
||||
std::string mfma = arg_parser.get_str("prec");
|
||||
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
|
||||
ck_tile::index_t mfma_type = 1;
|
||||
#else
|
||||
@@ -193,18 +199,18 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
|
||||
invoke_flatmm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_dev_buf.FromDevice(c_rslt_host.data());
|
||||
bool pass = true;
|
||||
@@ -219,8 +225,9 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
a_host, b_origin_host, c_ref_host);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_ref_host.mData.begin(), c_ref_host.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_rslt_host,
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_rslt_host,
|
||||
c_ref_host,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
@@ -277,8 +284,9 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_rslt_host,
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_rslt_host,
|
||||
c_gpu_ref_host,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
|
||||
Reference in New Issue
Block a user