From 6072031cf46e5b88c101a0e6b9dbbaca298aa8d2 Mon Sep 17 00:00:00 2001 From: msaffari-amd Date: Tue, 14 Apr 2026 22:22:18 +0200 Subject: [PATCH] [CK_TILE] Separate PermuteN epilogue from CShuffle epilogue into standalone file (#5863) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation The PermuteN epilogue was previously embedded within cshuffle_epilogue.hpp, despite having fundamentally different behaviour. Coupling these two independent strategies in one file introduced unnecessary complexity, SFINAE guards, and a dual operator() overload selected at compile time via the TiledMMAPermuteN_ template parameter. This PR separates PermuteN into its own standalone file (permuten_epilogue.hpp), simplifying both implementations and making the codebase easier to maintain and extend independently. ## Technical Details **New file: permuten_epilogue.hpp:** contains PermuteNEpilogueProblem and PermuteNEpilogue, extracted from the permuteN code path in cshuffle_epilogue.hpp. 
**Cleanup of cshuffle_epilogue.hpp:** - Removed the TiledMMAPermuteN_ template parameter from [CShuffleEpilogueProblem] - Removed the SFINAE-guarded permuteN operator() overload - Removed the EnablePermuateN_ SFINAE alias - CShuffle now only contains CShuffle logic; EightWave support (independent feature) is retained **Consumer migration:** All consumer files now use compile-time epilogue selection via [std::conditional_t] `using GemmEpilogue = std::conditional_t<TiledMMAPermuteN, PermuteNEpilogue<...>, CShuffleEpilogue<...>>;` **Files modified:** - flatmm_basic.cpp, moe_flatmm.cpp, a16w4_moe_flatmm.cpp, mixed_prec_flatmm.cpp, mx_flatmm_instance.hpp — flatmm examples - run_gemm_quant_example.inc — block-scale GEMM example - gemm_weight_preshuffle_invoker.hpp — weight preshuffle invoker - test_gemm_quant_fixtures.hpp, test_gemm_persistent_async_input.cpp, test_gemm_pipeline_util.hpp — test utilities - universal_gemm_invoker.hpp — universal GEMM invoker - epilogue.hpp — header updated to include permuten_epilogue.hpp ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> --- .../gemm_weight_preshuffle_invoker.hpp | 60 ++- .../03_gemm/universal_gemm_invoker.hpp | 2 - example/ck_tile/18_flatmm/flatmm_basic.cpp | 61 ++- .../18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp | 64 ++- .../mixed_prec/mixed_prec_flatmm.cpp | 64 ++- example/ck_tile/18_flatmm/moe_flatmm.cpp | 64 ++- .../18_flatmm/mxgemm/mx_flatmm_instance.hpp | 24 +- .../run_gemm_quant_example.inc | 59 ++- include/ck_tile/ops/epilogue.hpp | 1 + .../ops/epilogue/cshuffle_epilogue.hpp | 150 +------ .../ops/epilogue/permuten_epilogue.hpp | 375 ++++++++++++++++++ test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 1 - .../test_gemm_quant_fixtures.hpp | 125 ++++-- .../test_gemm_persistent_async_input.cpp | 11 +- 14 files changed, 728 insertions(+), 333 deletions(-) create mode 100644 include/ck_tile/ops/epilogue/permuten_epilogue.hpp diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp index 1deafb97a1..e4efd5763f 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp @@ -58,27 +58,45 @@ struct WeightPreshuffleInvoker using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = std::conditional_t< + GemmConfig::TiledMMAPermuteN, + ck_tile::PermuteNEpilogue< + ck_tile::PermuteNEpilogueProblem>, + ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index 660647dda9..1f98ed575d 100644 --- 
a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -84,7 +84,6 @@ struct UniversalInvoker GemmConfig::NumWaveGroups, false, /*FixedVectorSize_*/ 1, /*VectorSizeC_*/ - false, /*TiledMMAPermuteN_*/ 1, /*BlockedXDLN_PerWarp_*/ GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>; @@ -228,7 +227,6 @@ struct UniversalInvoker GemmConfig::NumWaveGroups, false, /*FixedVectorSize_*/ 1, /*VectorSizeC_*/ - false, /*TiledMMAPermuteN_*/ 1, /*BlockedXDLN_PerWarp_*/ GemmConfig::DoubleSmemBuffer>>; diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 19593a0f04..6295a4a48b 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -188,27 +188,45 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, using CodegenFlatmmPipeline = ck_tile::FlatmmPipelineAGmemBGmemCRegV1; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = std::conditional_t< + FlatmmConfig::TiledMMAPermuteN, + ck_tile::PermuteNEpilogue< + ck_tile::PermuteNEpilogueProblem>, + ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. 
@@ -230,6 +248,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, << "Shape: " << CodegenFlatmmShape::GetName() << "\n" << "problem: " << CodegenPipelineProblem::GetName() << "\n" << "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n" + << "epilogue: " << GemmEpilogue::GetName() << "\n" << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp index 708e8a683e..a1d3024364 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp @@ -139,28 +139,48 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = std::conditional_t< + FlatmmConfig::TiledMMAPermuteN, + ck_tile::PermuteNEpilogue< + ck_tile::PermuteNEpilogueProblem>, + ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>>; using CodegenFlatmmPipeline = std::conditional_t< MXFP4_Pipeline, diff --git a/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp index f9f8c0cec7..b7a5818afd 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp @@ -108,28 +108,48 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& using CodegenFlatmmPipeline = ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = std::conditional_t< + FlatmmConfig::TiledMMAPermuteN, + ck_tile::PermuteNEpilogue< + ck_tile::PermuteNEpilogueProblem>, 
// VectorSizeC + ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>>; using Kernel = ck_tile::F16xMXF4FlatmmKernel; diff --git a/example/ck_tile/18_flatmm/moe_flatmm.cpp b/example/ck_tile/18_flatmm/moe_flatmm.cpp index 4cca953066..4fb082cb9d 100644 --- a/example/ck_tile/18_flatmm/moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/moe_flatmm.cpp @@ -163,28 +163,48 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs& args, ? 2 : 1; // determined by scale shuffle pattern - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = std::conditional_t< + FlatmmConfig::TiledMMAPermuteN, + ck_tile::PermuteNEpilogue< + ck_tile::PermuteNEpilogueProblem>, + ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>>; using CodegenFlatmmPipeline = ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1; diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp index 90bd24d5dc..54e27d0baa 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp @@ -84,7 +84,26 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, ck_tile::GemmSpatiallyLocalTilePartitioner; - using GemmEpilogue = + using GemmEpilogue = std::conditional_t< + FlatmmConfig::TiledMMAPermuteN, + ck_tile::PermuteNEpilogue>, // VectorSizeC ck_tile::CShuffleEpilogue& args, FlatmmConfig::NumWaveGroups, false, // FixedVectorSize 1, // VectorSizeC - FlatmmConfig::TiledMMAPermuteN, - BlockedXDLN_PerWarp>>; + BlockedXDLN_PerWarp>>>; using Kernel = ck_tile::MXFlatmmKernel; diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index d89aa37ff8..46df80ae28 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -207,27 +207,44 @@ float 
gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str printf( "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN); } - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - 1, - false, - 1, - TiledPermuteN>>; + using GemmEpilogue = std::conditional_t< + TiledPermuteN, + ck_tile::PermuteNEpilogue< + ck_tile::PermuteNEpilogueProblem, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + false, + 1>>, + ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c>>>; using Kernel = ck_tile::QuantGemmKernel; diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index d1b38a8bca..b7a119d756 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -10,6 +10,7 @@ #include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" +#include "ck_tile/ops/epilogue/permuten_epilogue.hpp" #include 
"ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index fba831e205..b0e55d239f 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -33,7 +33,6 @@ template struct CShuffleEpilogueProblem @@ -59,7 +58,6 @@ struct CShuffleEpilogueProblem static constexpr index_t VectorSizeC = VectorSizeC_; static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_; static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; - static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_; static constexpr index_t kNumWaveGroups = kNumWaveGroups_; static constexpr index_t NumDTensor = DsDataType::size(); @@ -658,152 +656,8 @@ struct CShuffleEpilogue template = 0> - CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, - const OAccTile& o_acc_tile, - const DsDramWindows& ds_dram_windows, - void* /* p_smem */, - const ScaleM& scale_m = {}, - const ScaleN& scale_n = {}) - { - static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size(); - - static_assert(MPerXdl % RowsPerLane == 0, - "CShuffle (permuteN): MPerXdl must be divisible by per-lane row count."); - constexpr int kM0 = MWave; - constexpr int kM2 = RowsPerLane; - constexpr int kM1 = MPerXdl / kM2; - - constexpr int kN0 = NWave; - constexpr int kN1 = NPerXdl; - constexpr int kN2 = NRepeat; - - using IntrThreadShuffleEncode = - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 1>>, - sequence<1, 2>, - sequence<2, 2>>; - constexpr auto dram_tile_distribution = - make_static_tile_distribution(IntrThreadShuffleEncode{}); - - auto d_dram_windows = generate_tuple( - [&](auto idx) { - return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); - }, - number{}); - - 
constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - auto shuffle_acc = make_static_distributed_tensor(dram_tile_distribution); - auto c_out_tensor = make_static_distributed_tensor(dram_tile_distribution); - - // Optional scales (must share the same distribution to match per-thread indexing) - constexpr bool has_scales = - !std::is_same::value && !std::is_same::value; - constexpr bool has_scalar_scales = - std::is_same_v && std::is_same_v; - - // Tiles to hold row/col scales when present - using SMType = typename ScaleDataType::DataType; - using SNType = typename ScaleDataType::DataType; - - auto sm_tile = make_static_distributed_tensor(dram_tile_distribution); - auto sn_tile = make_static_distributed_tensor(dram_tile_distribution); - - // Build windows only if non-scalar scales are provided - auto scale_m_window = [&]() { - if constexpr(has_scales && !has_scalar_scales) - { - return make_tile_window(scale_m, dram_tile_distribution); - } - else - { - return EmptyScale{}; - } - }(); - auto scale_n_window = [&]() { - if constexpr(has_scales && !has_scalar_scales) - { - return make_tile_window(scale_n, dram_tile_distribution); - } - else - { - return EmptyScale{}; - } - }(); - - static_for<0, MRepeat, 1>{}([&](auto mIter) { - // Slice accumulators for this M repeat into the permuted layout - shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); - - // If non-scalar scales provided, load them with identical distribution - if constexpr(has_scales && !has_scalar_scales) - { - sm_tile = load_tile(scale_m_window); // row scales in permuted layout - sn_tile = load_tile(scale_n_window); // col scales in permuted layout - } - - // Pack 4 “rows per lane” as you already do - static_for<0, NRepeat, 1>{}([&](auto n_idx) { - // source 
indices in shuffle_acc: (n_idx * product(Y) + row) - const index_t plane = c_warp_y_lengths.product(); - - // local lambda to fuse scale (if present) and convert - static_for<0, kM2, 1>{}([&](auto m_lane) { - const int src = n_idx * plane + m_lane; // source row in this N-plane - const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output - AccDataType v = shuffle_acc.get_thread_buffer()[src]; - - if constexpr(has_scalar_scales) - { - v = static_cast(v * scale_m * scale_n); - } - else if constexpr(has_scales && !has_scalar_scales) - { - const auto sm = static_cast(sm_tile.get_thread_buffer()[dst]); - const auto sn = static_cast(sn_tile.get_thread_buffer()[dst]); - v = static_cast(v * sm * sn); - } - - c_out_tensor.get_thread_buffer()[dst] = type_convert(v); - }); - }); - - // store/update - if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp == - memory_operation_enum::set) - { - store_tile(out_dram_window, c_out_tensor); - } - else - { - update_tile(out_dram_window, c_out_tensor); - } - - // advance output (and any D-tensors) by one MPerXdl*MWave chunk - move_tile_window(out_dram_window, {number{}, number<0>{}}); - static_for<0, NumDTensor, 1>{}([&](auto idx) { - move_tile_window(d_dram_windows[idx], {number{}, number<0>{}}); - }); - }); - } - - template = 0> + typename ScaleM = EmptyScale, + typename ScaleN = EmptyScale> CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, const DsDramWindows& ds_dram_windows, diff --git a/include/ck_tile/ops/epilogue/permuten_epilogue.hpp b/include/ck_tile/ops/epilogue/permuten_epilogue.hpp new file mode 100644 index 0000000000..ffcae1b821 --- /dev/null +++ b/include/ck_tile/ops/epilogue/permuten_epilogue.hpp @@ -0,0 +1,375 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/host/concat.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/utils.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +#include + +namespace ck_tile { + +template +struct PermuteNEpilogueProblem +{ + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size(); + static constexpr index_t kMPerBlock = kM_; + static constexpr index_t kNPerBlock = kN_; + static constexpr index_t MWave = MWave_; + static constexpr index_t NWave = NWave_; + static constexpr index_t MPerXdl = MPerXdl_; + static constexpr index_t NPerXdl = NPerXdl_; + static constexpr index_t KPerXdl = KPerXdl_; + static constexpr index_t isCTransposed = isCTransposed_; + static constexpr bool FixedVectorSize = FixedVectorSize_; + static constexpr index_t VectorSizeC = VectorSizeC_; + static constexpr index_t NumDTensor = DsDataType::size(); + + static_assert(NumDTensor == DsLayout::size(), + "The size of DsDataType and DsLayout should be the same"); +}; + +template +struct PermuteNEpilogue +{ + using Problem = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + + static constexpr bool ADataTypeIsTuple = is_detected::value; + static constexpr bool BDataTypeIsTuple = is_detected::value; + + using AsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using BsDataTypeTuple = std::conditional_t, + 
remove_cvref_t>>; + + using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; + using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; + + using ATypeToUse = std::conditional_t || + std::is_same_v, + BDataType, + ADataType>; + // Used for weight-only quantization kernel, B would be dequantized to the same data type as A + using BTypeToUse = std::conditional_t || + std::is_same_v || + sizeof(BDataType) < sizeof(ADataType), + ADataType, + BDataType>; + + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kMPerBlock = Problem::kMPerBlock; + static constexpr index_t kNPerBlock = Problem::kNPerBlock; + static constexpr index_t MWave = Problem::MWave; + static constexpr index_t NWave = Problem::NWave; + static constexpr index_t MPerXdl = Problem::MPerXdl; + static constexpr index_t NPerXdl = Problem::NPerXdl; + static constexpr index_t KPerXdl = Problem::KPerXdl; + static constexpr index_t isCTransposed = Problem::isCTransposed; + static constexpr bool FixedVectorSize = Problem::FixedVectorSize; + static constexpr index_t VectorSizeC = Problem::VectorSizeC; + static constexpr index_t MPerIteration = MPerXdl * MWave; + static constexpr index_t NPerIteration = NPerXdl * NWave; + static constexpr index_t NumDTensor = Problem::NumDTensor; + static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave); + static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave); + + CDElementwise elfunc_; + + // PermuteN epilogue does not support D tensors or non-passthrough elementwise operations. + // If D tensor support is needed, use CShuffleEpilogue instead. + static_assert(NumDTensor == 0, + "PermuteNEpilogue does not support D tensors. Use CShuffleEpilogue instead."); + static_assert(std::is_same_v, + "PermuteNEpilogue only supports PassThrough elementwise. 
" + "Use CShuffleEpilogue for custom elementwise operations."); + + CK_TILE_DEVICE PermuteNEpilogue(CDElementwise elfunc = CDElementwise{}) : elfunc_(elfunc) {}; + + static_assert(NumDTensor == DsLayout::size(), + "The size of DsDataType and DsLayout should be the same"); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "PermuteNEpilogue", + concat('x', MWave, NWave), + concat('x', MPerXdl, NPerXdl, KPerXdl), + VectorSizeC, + isCTransposed ? "CTransposed" : "CNotTransposed"); + // clang-format on + } + + /** + * @brief Get the vector store size for C tensor. + * + * @note The vector store size for output C tensor would depend on multiple factors + * like its data layout and warp gemm C transposition. In general it would + * be the number of consecutive elements in contiguous C dimension hold by + * single thread. + * + * @return The vector store size for C tensor. + */ + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC() + { + if constexpr(FixedVectorSize) + { + return VectorSizeC; + } + constexpr index_t max_vector_size = 16; + if constexpr(std::is_same_v) + { + return std::min(static_cast(NPerIteration), + static_cast(max_vector_size / sizeof(ODataType))); + } + else if constexpr(std::is_same_v) + { + return std::min(static_cast(MPerIteration), + static_cast(max_vector_size / sizeof(ODataType))); + } + else + { + static_assert(false, "Unsupported ELayout!"); + } + } + + /** + * @brief Get the vector store size for Di tensor. + * + * @return The vector store size for Di tensor. 
+ */ + template + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number index) + { + constexpr index_t max_vector_size = 16; + using DiDataType = remove_cvref_t>; + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return std::min(static_cast(NPerIteration), + static_cast(max_vector_size / sizeof(DiDataType))); + } + else if constexpr(std::is_same_v) + { + return std::min(static_cast(MPerIteration), + static_cast(max_vector_size / sizeof(DiDataType))); + } + else + { + static_assert(false, "Unsupported DLayout!"); + } + return max_vector_size / sizeof(DiDataType); + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } + + using WG = WarpGemmDispatcher; + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + using CWarpDstrEncoding = typename WG::CWarpDstrEncoding; + + // TODO: Check if there would be nicer ways to overload rather than with EmptyScale or nullptr_t + struct EmptyScale + { + }; + + template + struct ScaleDataType + { + using DataType = float; + }; + + template + struct ScaleDataType> + { + using DataType = typename T::DataType; + }; + + template + CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, + const OAccTile& o_acc_tile, + const DsDramWindows& ds_dram_windows, + void* /* p_smem */, + const ScaleM& scale_m = {}, + const ScaleN& scale_n = {}) + { + static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size(); + + static_assert(MPerXdl % RowsPerLane == 0, + "PermuteN: MPerXdl must be divisible by per-lane row count."); + constexpr int kM0 = MWave; + constexpr int kM2 = RowsPerLane; + constexpr int kM1 = MPerXdl / kM2; + + constexpr int kN0 = NWave; + constexpr int kN1 = NPerXdl; + constexpr int kN2 = NRepeat; + + using IntrThreadShuffleEncode = + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 1>>, + sequence<1, 2>, + sequence<2, 2>>; + constexpr auto dram_tile_distribution = + 
make_static_tile_distribution(IntrThreadShuffleEncode{}); + + auto d_dram_windows = generate_tuple( + [&](auto idx) { + return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); + }, + number{}); + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + auto shuffle_acc = make_static_distributed_tensor(dram_tile_distribution); + auto c_out_tensor = make_static_distributed_tensor(dram_tile_distribution); + + // Optional scales (must share the same distribution to match per-thread indexing) + constexpr bool has_scales = + !std::is_same::value && !std::is_same::value; + constexpr bool has_scalar_scales = + std::is_same_v && std::is_same_v; + + // Tiles to hold row/col scales when present + using SMType = typename ScaleDataType::DataType; + using SNType = typename ScaleDataType::DataType; + + auto sm_tile = make_static_distributed_tensor(dram_tile_distribution); + auto sn_tile = make_static_distributed_tensor(dram_tile_distribution); + + // Build windows only if non-scalar scales are provided + auto scale_m_window = [&]() { + if constexpr(has_scales && !has_scalar_scales) + { + return make_tile_window(scale_m, dram_tile_distribution); + } + else + { + return EmptyScale{}; + } + }(); + auto scale_n_window = [&]() { + if constexpr(has_scales && !has_scalar_scales) + { + return make_tile_window(scale_n, dram_tile_distribution); + } + else + { + return EmptyScale{}; + } + }(); + + static_for<0, MRepeat, 1>{}([&](auto mIter) { + // Slice accumulators for this M repeat into the permuted layout + shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); + + // If non-scalar scales provided, load them with identical distribution + if constexpr(has_scales && !has_scalar_scales) + { + sm_tile = load_tile(scale_m_window); // 
row scales in permuted layout + sn_tile = load_tile(scale_n_window); // col scales in permuted layout + } + + // Pack "rows per lane" with permuted N layout + static_for<0, NRepeat, 1>{}([&](auto n_idx) { + // source indices in shuffle_acc: (n_idx * product(Y) + row) + const index_t plane = c_warp_y_lengths.product(); + + // Fuse scale (if present) and convert + static_for<0, kM2, 1>{}([&](auto m_lane) { + const int src = n_idx * plane + m_lane; // source row in this N-plane + const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output + AccDataType v = shuffle_acc.get_thread_buffer()[src]; + + if constexpr(has_scalar_scales) + { + v = static_cast(v * scale_m * scale_n); + } + else if constexpr(has_scales && !has_scalar_scales) + { + const auto sm = static_cast(sm_tile.get_thread_buffer()[dst]); + const auto sn = static_cast(sn_tile.get_thread_buffer()[dst]); + v = static_cast(v * sm * sn); + } + + c_out_tensor.get_thread_buffer()[dst] = type_convert(v); + }); + }); + + // store/update + if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp == + memory_operation_enum::set) + { + store_tile(out_dram_window, c_out_tensor); + } + else + { + update_tile(out_dram_window, c_out_tensor); + } + + // advance output (and any D-tensors) by one MPerXdl*MWave chunk + move_tile_window(out_dram_window, {number{}, number<0>{}}); + static_for<0, NumDTensor, 1>{}([&](auto idx) { + move_tile_window(d_dram_windows[idx], {number{}, number<0>{}}); + }); + }); + } +}; +} // namespace ck_tile diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index a4f06bed67..30d5b4f241 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -221,7 +221,6 @@ class TestCkTileGemmPipeline : public ::testing::Test 1, /*kNumWaveGroups_*/ false, /*FixedVectorSize_*/ 1, /*VectorSizeC_*/ - false, /*TiledMMAPermuteN_*/ 1, /*BlockedXDLN_PerWarp_*/ DoubleSmemBuffer 
/*DoubleSmemBuffer*/>>; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index b354d04219..8fbda4a3ce 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -937,29 +937,49 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase>, ck_tile::WPQuantBPipelineAgBgCrV2>; - using GemmEpilogue = ck_tile::CShuffleEpilogue, - ADataType, - BDataType>, - ck_tile::tuple<>, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - Base::M_Warp, - Base::N_Warp, - Base::M_Warp_Tile, - Base::N_Warp_Tile, - Base::K_Warp_Tile, - false, // transpose_c - 1, - false, - 1, - TiledMMAPermuteN>>; + // clang-format off + using BTypeForEpilogue = + std::conditional_t, ADataType, BDataType>; + // clang-format on + + using GemmEpilogue = std::conditional_t< + TiledMMAPermuteN, + ck_tile::PermuteNEpilogue< + ck_tile::PermuteNEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Base::M_Warp, + Base::N_Warp, + Base::M_Warp_Tile, + Base::N_Warp_Tile, + Base::K_Warp_Tile, + false, // transpose_c + false, + 1>>, + ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Base::M_Warp, + Base::N_Warp, + Base::M_Warp_Tile, + Base::N_Warp_Tile, + Base::K_Warp_Tile, + false>>>; // transpose_c using Kernel = ck_tile::QuantGemmKernel, ck_tile::ABQuantGemmPipelineAgBgCrCompV3>>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - 
ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - Base::M_Warp, - Base::N_Warp, - Base::M_Warp_Tile, - Base::N_Warp_Tile, - Base::K_Warp_Tile, - transpose_c, - 1, - false, - 1, - TiledMMAPermuteN>>; + using GemmEpilogue = std::conditional_t< + TiledMMAPermuteN, + ck_tile::PermuteNEpilogue< + ck_tile::PermuteNEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Base::M_Warp, + Base::N_Warp, + Base::M_Warp_Tile, + Base::N_Warp_Tile, + Base::K_Warp_Tile, + transpose_c, + false, + 1>>, + ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Base::M_Warp, + Base::N_Warp, + Base::M_Warp_Tile, + Base::N_Warp_Tile, + Base::K_Warp_Tile, + transpose_c>>>; using Kernel = ck_tile::QuantGemmKernel>; + 1, /*kNumWaveGroups_*/ + false, /*FixedVectorSize_*/ + 1, /*VectorSizeC_*/ + 1, /*BlockedXDLN_PerWarp_*/ + DoubleSmemBuffer /*DoubleSmemBuffer*/>>; using Kernel = ck_tile::GemmKernel;