This commit is contained in:
Feng Shijie
2025-07-28 08:24:51 +00:00
parent 5473f06461
commit 1b6d7cf407
4 changed files with 163 additions and 97 deletions

View File

@@ -271,10 +271,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
// ave_time =
// ck_tile::launch_kernel(s,
// ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
// Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
@@ -289,10 +289,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
// Run(has_hot_loop_,
// tail_number_,
// ck_tile::integral_constant<ck_tile::memory_operation_enum,
// ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
@@ -420,14 +420,14 @@ int run_flatmm_example(int argc, char* argv[])
{
if(data_type == "fp16")
{
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
argc, argv, Row{}, Col{}, Row{});
// run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
// argc, argv, Row{}, Col{}, Row{});
}
// else if(data_type == "bf16")
// {
// run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
// argc, argv, Row{}, Col{}, Row{});
// }
else if(data_type == "fp8")
{
if(scale_opt == 0)
@@ -441,19 +441,20 @@ int run_flatmm_example(int argc, char* argv[])
argc, argv, Row{}, Col{}, Row{});
}
}
else if(data_type == "bf8")
{
if(scale_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
}
}
// else if(data_type == "bf8")
// {
// if(scale_opt == 0)
// {
// run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
// argc, argv, Row{}, Col{}, Row{});
// }
// else
// {
// run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1,
// 1>(
// argc, argv, Row{}, Col{}, Row{});
// }
// }
else
{
throw std::runtime_error("Unsupported data_type!");
@@ -479,18 +480,18 @@ int main(int argc, char* argv[])
{
return !run_flatmm_example<FlatmmConfig16>(argc, argv);
}
else if(warp_tile == 1)
{
return !run_flatmm_example<FlatmmConfig32>(argc, argv);
}
else if(warp_tile == 2)
{
return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
}
else
{
return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
}
// else if(warp_tile == 1)
// {
// return !run_flatmm_example<FlatmmConfig32>(argc, argv);
// }
// else if(warp_tile == 2)
// {
// return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
// }
// else
// {
// return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
// }
}
catch(const std::runtime_error& e)
{

View File

@@ -56,6 +56,7 @@ int run_flatmm_example_with_layouts(int argc,
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
memset(a_host.data(), 0, 4);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
@@ -111,33 +112,6 @@ int run_flatmm_example_with_layouts(int argc,
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
invoke_flatmm<FlatmmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
decltype(per_token_scale_dev_ptr),
decltype(per_channel_scale_dev_ptr)>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
per_token_scale_dev_ptr,
per_channel_scale_dev_ptr,
n_warmup,
n_repeat);
c_dev_buf.FromDevice(c_rslt_host.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
@@ -236,6 +210,34 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
invoke_flatmm<FlatmmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
decltype(per_token_scale_dev_ptr),
decltype(per_channel_scale_dev_ptr)>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
per_token_scale_dev_ptr,
per_channel_scale_dev_ptr,
n_warmup,
n_repeat);
c_dev_buf.FromDevice(c_rslt_host.data());
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
const float max_accumulated_value =
*std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end());