use new pipeline in example

This commit is contained in:
Sami Remes
2026-01-13 09:25:13 -05:00
parent edd11c9852
commit 93ff8b07a2
3 changed files with 33 additions and 55 deletions

View File

@@ -90,45 +90,31 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
AccDataType,
GemmShape,
MXGemmTraits,
GemmConfig::Scheduler,
true, // HasHotLoop
ck_tile::TailNumber::Full>;
GemmConfig::Scheduler>;
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrV1<MXPipelineProblem>;
// Use the new comp_async pipeline with MX scaling support
using MXGemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(k_split);
const bool has_hot_loop = MXGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = MXGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time = MXGemmPipeline::template TailHandler<true>(
[&](auto has_hot_loop_, auto) {
constexpr auto has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_num_v = ck_tile::TailNumber::Full;
auto invoke_splitk_path = [&](auto split_k_) {
return mx_gemm_calc<GemmConfig,
ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
ScaleM,
ScaleN,
UsePersistentKernel,
split_k_.value,
has_hot_loop_v,
tail_num_v>(
args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
};
return (args.k_batch == 1) ? invoke_splitk_path(std::false_type{})
: invoke_splitk_path(std::true_type{});
},
has_hot_loop,
tail_num);
// Simplified invocation - comp_async handles hot loop and tail internally
auto invoke_splitk_path = [&](auto split_k_) {
return mx_gemm_calc<GemmConfig,
ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
ScaleM,
ScaleN,
UsePersistentKernel,
split_k_.value>(
args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
};
float ave_time = (args.k_batch == 1) ? invoke_splitk_path(std::false_type{})
: invoke_splitk_path(std::true_type{});
constexpr int APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr int BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;

View File

@@ -64,9 +64,9 @@ struct MxGemmConfig
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool DoubleSmemBuffer = true; // comp_async uses double buffer
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;

View File

@@ -5,29 +5,24 @@
#include "ck_tile/host.hpp"
#include "mx_gemm.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/mx_pipeline_ag_bg_cr_v1.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
#include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp"
template <typename Layout>
using is_row_major_t = ck_tile::bool_constant<
std::is_same_v<ck_tile::remove_cvref_t<Layout>, ck_tile::tensor_layout::gemm::RowMajor>>;
// Problem definition for MX GEMM with comp_async pipeline
// The comp_async pipeline handles MX scaling with OpSel parameters
template <typename ADataType,
typename BDataType,
typename CDataType,
typename BlockGemmShape,
typename Traits,
ck_tile::GemmPipelineScheduler Scheduler_ = ck_tile::GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
ck_tile::TailNumber TailNum_ = ck_tile::TailNumber::Full>
ck_tile::GemmPipelineScheduler Scheduler_ = ck_tile::GemmPipelineScheduler::Intrawave>
struct MXGemmPipelineProblem : ck_tile::GemmPipelineProblem<ADataType, BDataType, CDataType, BlockGemmShape, Traits>
{
static constexpr int MXdlPack = 1; // No M packing
static constexpr int NXdlPack = 1; // No N packing
static constexpr int KXdlPack = 4; // Pack 4 consecutive e8m0 scales in K = 4 bytes = 1 int32
static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
};
template <typename GemmConfig,
@@ -41,9 +36,7 @@ template <typename GemmConfig,
typename ScaleM,
typename ScaleN,
bool persistent,
bool Splitk,
bool HasHotLoop,
ck_tile::TailNumber TailNum>
bool Splitk>
float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s)
{
@@ -80,11 +73,10 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args,
AccDataType,
GemmShape,
MXGemmTraits,
scheduler,
HasHotLoop,
TailNum>;
scheduler>;
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrV1<MXPipelineProblem>;
// Use the new comp_async pipeline with MX scaling support
using MXGemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,