support flatmm scaling

This commit is contained in:
Feng Shijie
2025-07-23 19:04:22 +00:00
parent 3f7d848dd3
commit 5a1183ebbd
7 changed files with 476 additions and 318 deletions

View File

@@ -23,9 +23,12 @@ template <typename FlatmmConfig,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s)
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s)
{
using CodegenFlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
@@ -81,13 +84,13 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
constexpr auto memory_operation = memory_operation_.value;
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
@@ -217,6 +220,7 @@ int run_flatmm_example(int argc, char* argv[])
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
int scale_opt = arg_parser.get_int("scale");
if(a_layout == "R" && b_layout == "C")
{
if(data_type == "fp16")
@@ -231,13 +235,29 @@ int run_flatmm_example(int argc, char* argv[])
}
else if(data_type == "fp8")
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
if(scale_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
}
}
else if(data_type == "bf8")
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
if(scale_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
}
}
else
{

View File

@@ -83,10 +83,10 @@ struct FlatmmConfig16
template <typename DataType>
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
{
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
static constexpr int kBlockPerCu = 1;
static constexpr int kBlockPerCu = 1;
};
template <typename ADataType>
@@ -167,120 +167,6 @@ struct is_8bit_type
{
};
template <typename DataType>
struct GemmConfig
{
#if defined(USING_MFMA_16x16x128_F8) //MI350 FP8 16X16
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
#elif defined(USING_MFMA_32x32x64_F8) //MI350 FP8 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
#elif defined(USING_MFMA_16x16x32_F16) //MI350 FP16 16X16 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
#elif defined(USING_MFMA_32x32x16_F16) //MI350 FP16 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
#elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
#elif defined(USING_MFMA_32x32x16_F8) //MI300 FP8 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 8;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
#elif defined(USING_MFMA_16x16x16_F16) //MI300 FP16 16X16 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
#elif defined(USING_MFMA_32x32x8_F16) //MI300 FP16 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
#else
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
#endif
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -301,6 +187,7 @@ auto create_args(int argc, char* argv[])
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
.insert("warp_tile",
"0",
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");

View File

@@ -18,7 +18,7 @@ constexpr const char* DataTypeToString()
{
return "bf8";
}
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
return "bf16";
}
@@ -83,9 +83,12 @@ template <typename FlatmmConfig,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s);
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s);
template <typename FlatmmConfig,
typename ADataType,
@@ -97,6 +100,8 @@ template <typename FlatmmConfig,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ScaleM,
typename ScaleN,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_shuffle_dev_buf,
@@ -108,21 +113,25 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
ScaleM scale_m,
ScaleN scale_n,
int n_warmup,
int n_repeat)
{
ck_tile::FlatmmHostArgs<> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C};
ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C,
scale_m,
scale_n};
float ave_time = flatmm_calc<FlatmmConfig,
ADataType,
@@ -134,6 +143,8 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
BLayout,
DsLayout,
CLayout,
ScaleM,
ScaleN,
false,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
@@ -154,6 +165,8 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
template <typename PrecType,
typename FlatmmConfig,
int ScaleGranularityM = -1,
int ScaleGranularityN = -1,
typename ALayout,
typename BLayout,
typename CLayout>
@@ -197,21 +210,30 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::HostTensor<CDataType> c_rslt_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({M}, {1}));
ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1}));
// TODO: add different init types
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_host);
ck_tile::FillMonotonicSeq<BDataType>{}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
}
else
{
@@ -222,14 +244,25 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
ck_tile::DeviceMem per_channel_scale_dev_buf(
per_channel_scale.get_element_space_size_in_bytes());
a_dev_buf.ToDevice(a_host.data());
c_rslt_host.SetZero();
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
// do pre-shuffle
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig>(b_origin_host);
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
invoke_flatmm<FlatmmConfig,
ADataType,
BDataType,
@@ -239,18 +272,22 @@ int run_flatmm_example_with_layouts(int argc,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
CLayout,
decltype(per_token_scale_dev_ptr),
decltype(per_channel_scale_dev_ptr)>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
per_token_scale_dev_ptr,
per_channel_scale_dev_ptr,
n_warmup,
n_repeat);
c_dev_buf.FromDevice(c_rslt_host.data());
bool pass = true;
@@ -263,6 +300,8 @@ int run_flatmm_example_with_layouts(int argc,
if(arg_parser.get_int("v") == 1)
{
assert(ScaleGranularityM == -1 && ScaleGranularityN == -1 &&
"ScaleAB is not supported for CPU verification!");
ck_tile::HostTensor<CDataType> c_ref_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_ref_host.SetZero();
@@ -310,13 +349,41 @@ int run_flatmm_example_with_layouts(int argc,
N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
if constexpr(ScaleGranularityM == -1 && ScaleGranularityN == -1)
{
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
}
else
{
ck_tile::reference_blockwise_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
d_A,
d_B,
d_C,
M,
N,
K,
stride_A,
stride_B,
stride_C,
ScaleGranularityM,
ScaleGranularityN,
K,
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));
}
ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_dev_buf.GetDeviceBuffer(),
d_C,