Update example code: remove the commented-out MX GEMM pipeline setup and the warp_tile argument, among other changes

This commit is contained in:
Sami Remes
2026-01-27 12:57:04 -05:00
parent f62cc5415f
commit 08ec1f4192
3 changed files with 11 additions and 43 deletions

View File

@@ -60,41 +60,6 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
scale_m,
scale_n);
// using GemmShape = ck_tile::TileGemmShape<
// ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
// ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
// ck_tile::sequence<GemmConfig::M_Warp_Tile,
// GemmConfig::N_Warp_Tile,
// GemmConfig::K_Warp_Tile>>;
// using TilePartitioner =
// ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
// GemmConfig::TileParitionerGroupNum,
// GemmConfig::TileParitionerM01>;
// using MXGemmTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
// GemmConfig::kPadN,
// GemmConfig::kPadK,
// GemmConfig::DoubleSmemBuffer,
// ALayout,
// BLayout,
// CLayout,
// GemmConfig::TransposeC,
// GemmConfig::UseStructuredSparsity,
// UsePersistentKernel,
// GemmConfig::NumWaveGroups,
// false>;
// using MXPipelineProblem = MXGemmPipelineProblem<ADataType,
// BDataType,
// AccDataType,
// GemmShape,
// MXGemmTraits,
// GemmConfig::Scheduler>;
// // Use the new MX comp_async pipeline with MX scaling support
// using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
// Simplified invocation - comp_async handles hot loop and tail internally
auto invoke_splitk_path = [&](auto split_k_) {
return mx_gemm_calc<GemmConfig,
@@ -154,10 +119,7 @@ auto create_args(int argc, char* argv[])
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:constant(1)")
.insert("warp_tile",
"0",
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
.insert("init", "0", "0:random, 1:constant(1)");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}