support flatmm scaling

This commit is contained in:
Feng Shijie
2025-07-23 19:04:22 +00:00
parent 3f7d848dd3
commit 5a1183ebbd
7 changed files with 476 additions and 318 deletions

View File

@@ -18,7 +18,7 @@ constexpr const char* DataTypeToString()
{
return "bf8";
}
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
return "bf16";
}
@@ -83,9 +83,12 @@ template <typename FlatmmConfig,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s);
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s);
template <typename FlatmmConfig,
typename ADataType,
@@ -97,6 +100,8 @@ template <typename FlatmmConfig,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ScaleM,
typename ScaleN,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_shuffle_dev_buf,
@@ -108,21 +113,25 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
ScaleM scale_m,
ScaleN scale_n,
int n_warmup,
int n_repeat)
{
ck_tile::FlatmmHostArgs<> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C};
ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C,
scale_m,
scale_n};
float ave_time = flatmm_calc<FlatmmConfig,
ADataType,
@@ -134,6 +143,8 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
BLayout,
DsLayout,
CLayout,
ScaleM,
ScaleN,
false,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
@@ -154,6 +165,8 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
template <typename PrecType,
typename FlatmmConfig,
int ScaleGranularityM = -1,
int ScaleGranularityN = -1,
typename ALayout,
typename BLayout,
typename CLayout>
@@ -197,21 +210,30 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::HostTensor<CDataType> c_rslt_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({M}, {1}));
ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1}));
// TODO: add different init types
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_host);
ck_tile::FillMonotonicSeq<BDataType>{}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
}
else
{
@@ -222,14 +244,25 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
ck_tile::DeviceMem per_channel_scale_dev_buf(
per_channel_scale.get_element_space_size_in_bytes());
a_dev_buf.ToDevice(a_host.data());
c_rslt_host.SetZero();
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
// do pre-shuffle
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig>(b_origin_host);
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
invoke_flatmm<FlatmmConfig,
ADataType,
BDataType,
@@ -239,18 +272,22 @@ int run_flatmm_example_with_layouts(int argc,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
CLayout,
decltype(per_token_scale_dev_ptr),
decltype(per_channel_scale_dev_ptr)>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
per_token_scale_dev_ptr,
per_channel_scale_dev_ptr,
n_warmup,
n_repeat);
c_dev_buf.FromDevice(c_rslt_host.data());
bool pass = true;
@@ -263,6 +300,8 @@ int run_flatmm_example_with_layouts(int argc,
if(arg_parser.get_int("v") == 1)
{
assert(ScaleGranularityM == -1 && ScaleGranularityN == -1 &&
"ScaleAB is not supported for CPU verification!");
ck_tile::HostTensor<CDataType> c_ref_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_ref_host.SetZero();
@@ -310,13 +349,41 @@ int run_flatmm_example_with_layouts(int argc,
N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
if constexpr(ScaleGranularityM == -1 && ScaleGranularityN == -1)
{
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
}
else
{
ck_tile::reference_blockwise_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
d_A,
d_B,
d_C,
M,
N,
K,
stride_A,
stride_B,
stride_C,
ScaleGranularityM,
ScaleGranularityN,
K,
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));
}
ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_dev_buf.GetDeviceBuffer(),
d_C,