This commit is contained in:
yadai
2025-10-20 04:11:37 +00:00
parent 10a288c3a2
commit 5962722a1a
3 changed files with 78 additions and 64 deletions

View File

@@ -564,7 +564,7 @@ include_directories(BEFORE
${HIP_INCLUDE_DIRS}
)
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
SET(BUILD_DEV OFF CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
add_compile_options(-Weverything)

View File

@@ -13,6 +13,56 @@
#include "flatmm_basic.hpp"
#include <type_traits>
/// @brief Registers the command-line options of the flatmm example and parses them.
///
/// Every option is inserted with its default value and help text, after which
/// argc/argv are handed to ck_tile::ArgParser::parse.
///
/// @param argc Argument count, forwarded unchanged to the parser.
/// @param argv Argument vector, forwarded unchanged to the parser.
/// @return std::tuple of (the bool returned by parse, a copy of the populated parser).
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "256", "m dimension")
        .insert("n", "256", "n dimension")
        .insert("k", "128", "k dimension")
        .insert("a_layout", "R", "A tensor data layout - Row by default")
        // The registered default is "C", so the help text names Column
        // (it previously claimed "Row by default").
        .insert("b_layout", "C", "B tensor data layout - Column by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
        .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
        .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
        .insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)")
        .insert("warmup", "50", "number of iterations before benchmark the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert("split_k", "1", "splitK value")
        .insert("init", "0", "0:random, 1:linear, 2:constant(1)")
        .insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
        .insert("persistent", "0", "0: no persistent, 1: persistent kernel")
        .insert("warp_tile",
                "0",
                "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");

    const bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
/// @brief Rearranges a rank-2 host tensor into a permuted 7-D blocked layout.
///
/// The flat contents of @p t are copied unchanged into a 7-D tensor whose
/// extents are derived from FlatmmConfig, and that tensor is then permuted
/// with axis order {0, 3, 1, 4, 5, 2, 6}.
///
/// @tparam FlatmmConfig Supplies the compile-time constants N_Tile, N_Warp,
///                      N_Warp_Tile and K_Warp_Tile.
/// @tparam T            Element type of the tensor.
/// @param  t            Rank-2 tensor; lengths[0] is read as K and lengths[1] as N.
/// @return Whatever ck_tile::reference_permute returns for the 7-D view.
template <typename FlatmmConfig, typename T>
auto shuffle_b_v1(const ck_tile::HostTensor<T>& t)
{
    assert(t.get_lengths().size() == 2);
    const int n_ = static_cast<int>(t.get_lengths()[1]);
    const int k_ = static_cast<int>(t.get_lengths()[0]);

    // The innermost K chunk is split in 2 for a 32-wide N warp tile, in 4 otherwise.
    constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
    constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile / FlatmmConfig::N_Warp;

    // The seven extents below multiply back to n_ * k_ only when every integer
    // division is exact. If one truncates, the view (presumably sized by the
    // product of its extents) is smaller than t and the flat copy would write
    // past its end, so reject such shapes in debug builds.
    assert(n_ % FlatmmConfig::N_Tile == 0);
    assert(k_ % FlatmmConfig::K_Warp_Tile == 0);
    assert(FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp_Tile * FlatmmConfig::N_Warp) == 0);
    assert(FlatmmConfig::K_Warp_Tile % divisor == 0);

    ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Tile,
                                   FlatmmConfig::N_Warp,
                                   FlatmmConfig::N_Warp_Tile,
                                   NRepeat,
                                   k_ / FlatmmConfig::K_Warp_Tile,
                                   divisor,
                                   FlatmmConfig::K_Warp_Tile / divisor});

    // Same element order, new shape: the data is only reinterpreted here.
    std::copy(t.begin(), t.end(), t_view.begin());
    return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}
#include "run_flatmm_example.inc"
/*
template <typename T>
constexpr const char* DataTypeToString()
{
@@ -62,25 +112,6 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
/// @brief Rearranges a rank-2 host tensor into a permuted 7-D blocked layout.
///
/// The flat contents of @p t are copied unchanged into a 7-D tensor whose
/// extents are derived from FlatmmConfig, and that tensor is then permuted
/// with axis order {0, 3, 1, 4, 5, 2, 6}.
///
/// @tparam FlatmmConfig Supplies the compile-time constants N_Tile, N_Warp,
///                      N_Warp_Tile and K_Warp_Tile.
/// @tparam T            Element type of the tensor.
/// @param  t            Rank-2 tensor; lengths[0] is read as K and lengths[1] as N.
/// @return Whatever ck_tile::reference_permute returns for the 7-D view.
template <typename FlatmmConfig, typename T>
auto shuffle_b_v1(const ck_tile::HostTensor<T>& t)
{
    assert(t.get_lengths().size() == 2);
    const int n_ = static_cast<int>(t.get_lengths()[1]);
    const int k_ = static_cast<int>(t.get_lengths()[0]);

    // The innermost K chunk is split in 2 for a 32-wide N warp tile, in 4 otherwise.
    constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
    constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile / FlatmmConfig::N_Warp;

    // The seven extents below multiply back to n_ * k_ only when every integer
    // division is exact. If one truncates, the view (presumably sized by the
    // product of its extents) is smaller than t and the flat copy would write
    // past its end, so reject such shapes in debug builds.
    assert(n_ % FlatmmConfig::N_Tile == 0);
    assert(k_ % FlatmmConfig::K_Warp_Tile == 0);
    assert(FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp_Tile * FlatmmConfig::N_Warp) == 0);
    assert(FlatmmConfig::K_Warp_Tile % divisor == 0);

    ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Tile,
                                   FlatmmConfig::N_Warp,
                                   FlatmmConfig::N_Warp_Tile,
                                   NRepeat,
                                   k_ / FlatmmConfig::K_Warp_Tile,
                                   divisor,
                                   FlatmmConfig::K_Warp_Tile / divisor});

    // Same element order, new shape: the data is only reinterpreted here.
    std::copy(t.begin(), t.end(), t_view.begin());
    return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
@@ -101,6 +132,7 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
*/
template <typename FlatmmConfig,
typename ADataType,
@@ -299,6 +331,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
return ave_time;
}
/*
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
@@ -372,37 +405,8 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
return ave_time;
}
*/
/// @brief Registers the command-line options of the flatmm example and parses them.
///
/// Every option is inserted with its default value and help text, after which
/// argc/argv are handed to ck_tile::ArgParser::parse.
///
/// @param argc Argument count, forwarded unchanged to the parser.
/// @param argv Argument vector, forwarded unchanged to the parser.
/// @return std::tuple of (the bool returned by parse, a copy of the populated parser).
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "256", "m dimension")
        .insert("n", "256", "n dimension")
        .insert("k", "128", "k dimension")
        .insert("a_layout", "R", "A tensor data layout - Row by default")
        // The registered default is "C", so the help text names Column
        // (it previously claimed "Row by default").
        .insert("b_layout", "C", "B tensor data layout - Column by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
        .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
        .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
        .insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)")
        .insert("warmup", "50", "number of iterations before benchmark the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert("split_k", "1", "splitK value")
        .insert("init", "0", "0:random, 1:linear, 2:constant(1)")
        .insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
        .insert("persistent", "0", "0: no persistent, 1: persistent kernel")
        .insert("warp_tile",
                "0",
                "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");

    const bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
#include "run_flatmm_example.inc"
template <template <typename PreType> typename FlatmmConfig>
int run_flatmm_example(int argc, char* argv[])

View File

@@ -97,7 +97,10 @@ template <typename FlatmmConfig,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
typename ScaleM,
typename ScaleN,
bool UsePersistentKernel = false,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_shuffle_dev_buf,
ck_tile::DeviceMem& c_dev_buf,
@@ -108,21 +111,25 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
ScaleM scale_m,
ScaleN scale_n,
int n_warmup,
int n_repeat)
{
ck_tile::FlatmmHostArgs<> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C};
ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C,
scale_m,
scale_n};
float ave_time = flatmm_calc<FlatmmConfig,
ADataType,
@@ -134,7 +141,9 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
BLayout,
DsLayout,
CLayout,
false,
ScaleM,
ScaleN,
UsePersistentKernel,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
@@ -152,6 +161,7 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
return ave_time;
}
template <typename PrecType,
typename FlatmmConfig,
int ScaleGranularityM = -1,