use new pipeline in example

This commit is contained in:
Sami Remes
2026-01-13 09:25:13 -05:00
parent edd11c9852
commit 93ff8b07a2
3 changed files with 33 additions and 55 deletions

View File

@@ -90,45 +90,31 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
AccDataType,
GemmShape,
MXGemmTraits,
GemmConfig::Scheduler,
true, // HasHotLoop
ck_tile::TailNumber::Full>;
GemmConfig::Scheduler>;
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrV1<MXPipelineProblem>;
// Use the new comp_async pipeline with MX scaling support
using MXGemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(k_split);
const bool has_hot_loop = MXGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = MXGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time = MXGemmPipeline::template TailHandler<true>(
[&](auto has_hot_loop_, auto) {
constexpr auto has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_num_v = ck_tile::TailNumber::Full;
auto invoke_splitk_path = [&](auto split_k_) {
return mx_gemm_calc<GemmConfig,
ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
ScaleM,
ScaleN,
UsePersistentKernel,
split_k_.value,
has_hot_loop_v,
tail_num_v>(
args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
};
return (args.k_batch == 1) ? invoke_splitk_path(std::false_type{})
: invoke_splitk_path(std::true_type{});
},
has_hot_loop,
tail_num);
// Simplified invocation - comp_async handles hot loop and tail internally
// Forwards `args` to mx_gemm_calc, selecting the last template argument at
// compile time from the tag passed in. The tag type must expose a static
// constexpr `value`; the call site below passes std::false_type or
// std::true_type depending on args.k_batch.
// NOTE(review): the stream_config fields are given positionally; only
// n_warmup and n_repeat are identifiable by name from here - confirm the
// meaning of the remaining literals against the stream_config declaration.
auto invoke_splitk_path = [&](auto split_k_tag) {
    // Read the flag from the tag's type rather than from the object itself.
    using SplitKTag = decltype(split_k_tag);
    return mx_gemm_calc<GemmConfig,
                        ADataType, BDataType, AccDataType, CDataType,
                        ALayout, BLayout, CLayout,
                        ScaleM, ScaleN,
                        UsePersistentKernel,
                        SplitKTag::value>(
        args,
        ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
};
float ave_time = (args.k_batch == 1) ? invoke_splitk_path(std::false_type{})
: invoke_splitk_path(std::true_type{});
constexpr int APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr int BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;