mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
use new pipeline in example
This commit is contained in:
@@ -90,45 +90,31 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
MXGemmTraits,
|
||||
GemmConfig::Scheduler,
|
||||
true, // HasHotLoop
|
||||
ck_tile::TailNumber::Full>;
|
||||
GemmConfig::Scheduler>;
|
||||
|
||||
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrV1<MXPipelineProblem>;
|
||||
// Use the new comp_async pipeline with MX scaling support
|
||||
using MXGemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(k_split);
|
||||
const bool has_hot_loop = MXGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = MXGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time = MXGemmPipeline::template TailHandler<true>(
|
||||
[&](auto has_hot_loop_, auto) {
|
||||
constexpr auto has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_num_v = ck_tile::TailNumber::Full;
|
||||
auto invoke_splitk_path = [&](auto split_k_) {
|
||||
return mx_gemm_calc<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ScaleM,
|
||||
ScaleN,
|
||||
UsePersistentKernel,
|
||||
split_k_.value,
|
||||
has_hot_loop_v,
|
||||
tail_num_v>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
};
|
||||
return (args.k_batch == 1) ? invoke_splitk_path(std::false_type{})
|
||||
: invoke_splitk_path(std::true_type{});
|
||||
},
|
||||
has_hot_loop,
|
||||
tail_num);
|
||||
// Simplified invocation - comp_async handles hot loop and tail internally
|
||||
auto invoke_splitk_path = [&](auto split_k_) {
|
||||
return mx_gemm_calc<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ScaleM,
|
||||
ScaleN,
|
||||
UsePersistentKernel,
|
||||
split_k_.value>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
};
|
||||
|
||||
float ave_time = (args.k_batch == 1) ? invoke_splitk_path(std::false_type{})
|
||||
: invoke_splitk_path(std::true_type{});
|
||||
|
||||
constexpr int APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr int BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
Reference in New Issue
Block a user