mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
use new pipeline in example
This commit is contained in:
@@ -90,45 +90,31 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
MXGemmTraits,
|
||||
GemmConfig::Scheduler,
|
||||
true, // HasHotLoop
|
||||
ck_tile::TailNumber::Full>;
|
||||
GemmConfig::Scheduler>;
|
||||
|
||||
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrV1<MXPipelineProblem>;
|
||||
// Use the new comp_async pipeline with MX scaling support
|
||||
using MXGemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(k_split);
|
||||
const bool has_hot_loop = MXGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = MXGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time = MXGemmPipeline::template TailHandler<true>(
|
||||
[&](auto has_hot_loop_, auto) {
|
||||
constexpr auto has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_num_v = ck_tile::TailNumber::Full;
|
||||
auto invoke_splitk_path = [&](auto split_k_) {
|
||||
return mx_gemm_calc<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ScaleM,
|
||||
ScaleN,
|
||||
UsePersistentKernel,
|
||||
split_k_.value,
|
||||
has_hot_loop_v,
|
||||
tail_num_v>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
};
|
||||
return (args.k_batch == 1) ? invoke_splitk_path(std::false_type{})
|
||||
: invoke_splitk_path(std::true_type{});
|
||||
},
|
||||
has_hot_loop,
|
||||
tail_num);
|
||||
// Simplified invocation - comp_async handles hot loop and tail internally
|
||||
auto invoke_splitk_path = [&](auto split_k_) {
|
||||
return mx_gemm_calc<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ScaleM,
|
||||
ScaleN,
|
||||
UsePersistentKernel,
|
||||
split_k_.value>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
};
|
||||
|
||||
float ave_time = (args.k_batch == 1) ? invoke_splitk_path(std::false_type{})
|
||||
: invoke_splitk_path(std::true_type{});
|
||||
|
||||
constexpr int APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr int BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
@@ -64,9 +64,9 @@ struct MxGemmConfig
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool DoubleSmemBuffer = true; // comp_async uses double buffer
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
|
||||
@@ -5,29 +5,24 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "mx_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/mx_pipeline_ag_bg_cr_v1.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
using is_row_major_t = ck_tile::bool_constant<
|
||||
std::is_same_v<ck_tile::remove_cvref_t<Layout>, ck_tile::tensor_layout::gemm::RowMajor>>;
|
||||
|
||||
// Problem definition for MX GEMM with comp_async pipeline
|
||||
// The comp_async pipeline handles MX scaling with OpSel parameters
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename BlockGemmShape,
|
||||
typename Traits,
|
||||
ck_tile::GemmPipelineScheduler Scheduler_ = ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
ck_tile::TailNumber TailNum_ = ck_tile::TailNumber::Full>
|
||||
ck_tile::GemmPipelineScheduler Scheduler_ = ck_tile::GemmPipelineScheduler::Intrawave>
|
||||
struct MXGemmPipelineProblem : ck_tile::GemmPipelineProblem<ADataType, BDataType, CDataType, BlockGemmShape, Traits>
|
||||
{
|
||||
static constexpr int MXdlPack = 1; // No M packing
|
||||
static constexpr int NXdlPack = 1; // No N packing
|
||||
static constexpr int KXdlPack = 4; // Pack 4 consecutive e8m0 scales in K = 4 bytes = 1 int32
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
};
|
||||
|
||||
template <typename GemmConfig,
|
||||
@@ -41,9 +36,7 @@ template <typename GemmConfig,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
bool Splitk,
|
||||
bool HasHotLoop,
|
||||
ck_tile::TailNumber TailNum>
|
||||
bool Splitk>
|
||||
float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
@@ -80,11 +73,10 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
MXGemmTraits,
|
||||
scheduler,
|
||||
HasHotLoop,
|
||||
TailNum>;
|
||||
scheduler>;
|
||||
|
||||
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrV1<MXPipelineProblem>;
|
||||
// Use the new comp_async pipeline with MX scaling support
|
||||
using MXGemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
|
||||
Reference in New Issue
Block a user