This commit is contained in:
Sami Remes
2026-01-14 12:07:26 -05:00
parent 5d4e07e095
commit f6f9931541
5 changed files with 181 additions and 122 deletions

View File

@@ -92,8 +92,8 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
MXGemmTraits,
GemmConfig::Scheduler>;
// Use the new comp_async pipeline with MX scaling support
using MXGemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
// Use the new MX comp_async pipeline with MX scaling support
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
// Simplified invocation - comp_async handles hot loop and tail internally
auto invoke_splitk_path = [&](auto split_k_) {

View File

@@ -25,6 +25,15 @@ struct MXGemmPipelineProblem : ck_tile::GemmPipelineProblem<ADataType, BDataType
static constexpr auto Scheduler = Scheduler_;
};
// Epilogue wrapper that adds MemoryOperation member for MX GEMM kernel compatibility
template <typename BaseEpilogue_, ck_tile::memory_operation_enum MemOp_>
struct MXGemmEpilogueWrapper : BaseEpilogue_
{
static constexpr ck_tile::memory_operation_enum MemoryOperation = MemOp_;
using BaseEpilogue_::BaseEpilogue_;
using BaseEpilogue_::operator();
};
template <typename GemmConfig,
typename ADataType,
typename BDataType,
@@ -75,15 +84,15 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args,
MXGemmTraits,
scheduler>;
// Use the new comp_async pipeline with MX scaling support
using MXGemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
// Use the new MX comp_async pipeline with MX scaling support
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using GemmEpilogue =
using BaseEpilogue =
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<ComputeDataType,
ComputeDataType,
ck_tile::tuple<>, // DsDataType
@@ -100,11 +109,14 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
MXPipelineProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups,
GemmConfig::NumWaveGroups, // kNumWaveGroups
false, // FixedVectorSize
1, // VectorSizeC
false>>; // PermuteN
false, // TiledMMAPermuteN
1, // BlockedXDLN_PerWarp
false>>; // DoubleSmemBuffer
using GemmEpilogue = MXGemmEpilogueWrapper<BaseEpilogue, memory_operation>;
using Kernel = ck_tile::MXGemmKernel<TilePartitioner, MXGemmPipeline, GemmEpilogue>;