merge flatmm -scale

This commit is contained in:
lalala-sh
2025-07-24 08:46:51 +00:00
8 changed files with 1265 additions and 654 deletions

View File

@@ -11,9 +11,9 @@
#include "ck_tile/host.hpp"
#include "flatmm_basic.hpp"
#include "run_flatmm_example.inc"
#include <type_traits>
template <typename T>
constexpr const char* DataTypeToString()
{
@@ -29,7 +29,7 @@ constexpr const char* DataTypeToString()
{
return "bf8";
}
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
return "bf16";
}
@@ -84,70 +84,6 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_shuffle_dev_buf,
ck_tile::DeviceMem& c_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
ck_tile::FlatmmHostArgs<> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C};
float ave_time = flatmm_calc<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
false,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
template <typename FlatmmConfig,
typename ADataType,
@@ -159,9 +95,12 @@ template <typename FlatmmConfig,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s)
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s)
{
using CodegenFlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
@@ -217,13 +156,13 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
constexpr auto memory_operation = memory_operation_.value;
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV0<CodegenPipelineProblem>;
@@ -340,6 +279,79 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
return ave_time;
}
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ScaleM,
typename ScaleN,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_shuffle_dev_buf,
ck_tile::DeviceMem& c_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
ScaleM scale_m,
ScaleN scale_n,
int n_warmup,
int n_repeat)
{
ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C,
scale_m,
scale_n};
float ave_time = flatmm_calc<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
ScaleM,
ScaleN,
false,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -360,6 +372,7 @@ auto create_args(int argc, char* argv[])
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
.insert("warp_tile",
"0",
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
@@ -367,6 +380,8 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
#include "run_flatmm_example.inc"
template <template <typename PreType> typename FlatmmConfig>
int run_flatmm_example(int argc, char* argv[])
{
@@ -380,6 +395,7 @@ int run_flatmm_example(int argc, char* argv[])
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
int scale_opt = arg_parser.get_int("scale");
if(a_layout == "R" && b_layout == "C")
{
if(data_type == "fp16")
@@ -394,13 +410,29 @@ int run_flatmm_example(int argc, char* argv[])
}
else if(data_type == "fp8")
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
if(scale_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
}
}
else if(data_type == "bf8")
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
if(scale_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
}
}
else
{