Update f16xMXF4 flatmm

This commit is contained in:
Feng Shijie
2025-08-13 16:16:48 +00:00
parent 732ebdee8b
commit 5de6208952
6 changed files with 113 additions and 48 deletions

View File

@@ -97,6 +97,8 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
using CodegenPipelineProblem = ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
@@ -129,9 +131,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups,
false,
1,
FlatmmConfig::TiledMMAPermuteN>>;
false, // FixedVectorSize
1, // VectorSizeC
FlatmmConfig::TiledMMAPermuteN,
BlockedXDLN_PerWarp>>;
using Kernel =
ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
@@ -211,10 +214,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
}
else
{
// Run(has_hot_loop_,
// tail_number_,
// ck_tile::integral_constant<ck_tile::memory_operation_enum,
// ck_tile::memory_operation_enum::atomic_add>{});
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
@@ -412,17 +415,17 @@ int run_mixed_prec_flatmm_example(int argc, char* argv[])
{
if(persistent_opt == 0)
{
// run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
// ck_tile::pk_fp4_t,
// FlatmmConfig,
// false>(argc, argv, Row{}, Col{}, Row{});
run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
false>(argc, argv, Row{}, Col{}, Row{});
}
else
{
// run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
// ck_tile::pk_fp4_t,
// FlatmmConfig,
// true>(argc, argv, Row{}, Col{}, Row{});
run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
true>(argc, argv, Row{}, Col{}, Row{});
}
}
else if(mixed_prec == "fp16xfp4")
@@ -434,13 +437,13 @@ int run_mixed_prec_flatmm_example(int argc, char* argv[])
FlatmmConfig,
false>(argc, argv, Row{}, Col{}, Row{});
}
// else
// {
// run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
// ck_tile::pk_fp4_t,
// FlatmmConfig,
// true>(argc, argv, Row{}, Col{}, Row{});
// }
else
{
run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
true>(argc, argv, Row{}, Col{}, Row{});
}
}
else
{
@@ -466,10 +469,10 @@ int main(int argc, char* argv[])
{
return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
}
// else if(warp_tile == 1)
// {
// return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
// }
else if(warp_tile == 1)
{
return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
}
else
{
throw std::runtime_error("Unsupported warp_tile!");

View File

@@ -58,8 +58,8 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleType>{-8.f, 8.f}(scale_b);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_b);
}
else if(init_method == 1)
{
@@ -165,8 +165,8 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
const float rtol = 5e-3;
const float atol = 1e-3;
const float rtol = 1e-2;
const float atol = 1e-2;
pass = ck_tile::check_err(
c_rslt_host, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);