merge flatmm -scale

This commit is contained in:
lalala-sh
2025-07-24 08:46:51 +00:00
8 changed files with 1265 additions and 654 deletions

View File

@@ -11,9 +11,9 @@
#include "ck_tile/host.hpp"
#include "flatmm_basic.hpp"
#include "run_flatmm_example.inc"
#include <type_traits>
template <typename T>
constexpr const char* DataTypeToString()
{
@@ -29,7 +29,7 @@ constexpr const char* DataTypeToString()
{
return "bf8";
}
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
return "bf16";
}
@@ -84,70 +84,6 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_shuffle_dev_buf,
ck_tile::DeviceMem& c_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
ck_tile::FlatmmHostArgs<> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C};
float ave_time = flatmm_calc<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
false,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
template <typename FlatmmConfig,
typename ADataType,
@@ -159,9 +95,12 @@ template <typename FlatmmConfig,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s)
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s)
{
using CodegenFlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
@@ -217,13 +156,13 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
constexpr auto memory_operation = memory_operation_.value;
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV0<CodegenPipelineProblem>;
@@ -340,6 +279,79 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
return ave_time;
}
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ScaleM,
typename ScaleN,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_shuffle_dev_buf,
ck_tile::DeviceMem& c_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
ScaleM scale_m,
ScaleN scale_n,
int n_warmup,
int n_repeat)
{
ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C,
scale_m,
scale_n};
float ave_time = flatmm_calc<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
ScaleM,
ScaleN,
false,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -360,6 +372,7 @@ auto create_args(int argc, char* argv[])
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
.insert("warp_tile",
"0",
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
@@ -367,6 +380,8 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
#include "run_flatmm_example.inc"
template <template <typename PreType> typename FlatmmConfig>
int run_flatmm_example(int argc, char* argv[])
{
@@ -380,6 +395,7 @@ int run_flatmm_example(int argc, char* argv[])
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
int scale_opt = arg_parser.get_int("scale");
if(a_layout == "R" && b_layout == "C")
{
if(data_type == "fp16")
@@ -394,13 +410,29 @@ int run_flatmm_example(int argc, char* argv[])
}
else if(data_type == "fp8")
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
if(scale_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
}
}
else if(data_type == "bf8")
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
if(scale_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
}
}
else
{

View File

@@ -83,10 +83,10 @@ struct FlatmmConfig16
template <typename DataType>
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
{
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
static constexpr int kBlockPerCu = 2;
static constexpr int kBlockPerCu = 1;
};
template <typename ADataType>
@@ -177,6 +177,9 @@ template <typename FlatmmConfig,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s);
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s);

View File

@@ -4,6 +4,8 @@
template <typename PrecType,
typename FlatmmConfig,
int ScaleGranularityM = -1,
int ScaleGranularityN = -1,
typename ALayout,
typename BLayout,
typename CLayout>
@@ -47,21 +49,30 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::HostTensor<CDataType> c_rslt_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({M}, {1}));
ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1}));
// TODO: add different init types
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_host);
ck_tile::FillMonotonicSeq<BDataType>{}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
}
else
{
@@ -72,14 +83,25 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
ck_tile::DeviceMem per_channel_scale_dev_buf(
per_channel_scale.get_element_space_size_in_bytes());
a_dev_buf.ToDevice(a_host.data());
c_rslt_host.SetZero();
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
// do pre-shuffle
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig>(b_origin_host);
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
invoke_flatmm<FlatmmConfig,
ADataType,
BDataType,
@@ -89,24 +111,30 @@ int run_flatmm_example_with_layouts(int argc,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
CLayout,
decltype(per_token_scale_dev_ptr),
decltype(per_channel_scale_dev_ptr)>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
per_token_scale_dev_ptr,
per_channel_scale_dev_ptr,
n_warmup,
n_repeat);
c_dev_buf.FromDevice(c_rslt_host.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
if (ScaleGranularityM == -1 || ScaleGranularityN == -1)
throw std::runtime_error("ScaleAB is not supported for CPU verification!\n");
ck_tile::HostTensor<CDataType> c_ref_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_ref_host.SetZero();
@@ -154,13 +182,41 @@ int run_flatmm_example_with_layouts(int argc,
N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
if constexpr(ScaleGranularityM == -1 && ScaleGranularityN == -1)
{
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
}
else
{
ck_tile::reference_blockwise_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
d_A,
d_B,
d_C,
M,
N,
K,
stride_A,
stride_B,
stride_C,
ScaleGranularityM,
ScaleGranularityN,
K,
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));
}
ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_dev_buf.GetDeviceBuffer(),
d_C,
@@ -189,4 +245,4 @@ int run_flatmm_example_with_layouts(int argc,
}
return pass;
}
}