// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// NOTE(review): in this copy of the file many angle-bracket template parameter/argument lists
// appear to be missing (e.g. `template concept HasDataType` declares no `T`,
// `using Problem = remove_cvref_t;` has no argument, and one `#include` below names no
// header). As shown, the code will not compile; restore the lists from the upstream
// Composable Kernel source before building. Comments that depend on the missing arguments
// are marked as assumptions.

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
// NOTE(review): the header name of the next include is missing in this copy — confirm upstream.
#include

namespace ck_tile {

// Satisfied by any type that exposes a nested `DataType` member type.
template
concept HasDataType = requires { typename T::DataType; };

// Resolves an element type: the type's nested `DataType` when it has one (constrained
// specialization below), otherwise `float`. Used further down to pick the element type of the
// row/column scale inputs.
template
struct GetDataType
{
    using type = float;
};

// Specialization for types satisfying HasDataType: forward their nested DataType.
template
    requires HasDataType
struct GetDataType
{
    using type = typename T::DataType; // e.g. for a scale tile window, its element type
};

/**
 * @brief Compile-time configuration consumed by CShuffleEpilogue.
 *
 * Bundles the A/B/accumulator/output/D data types, the D and E (output) layouts, the
 * elementwise functor applied together with the D tensors, the block/wave/XDL tile shape and
 * output-store options. Most members re-export a template parameter (NOTE(review): the
 * parameter list is not visible in this copy; the trailing-underscore names are assumed to be
 * those parameters). kBlockSize and NumDTensor are derived.
 */
template
struct CShuffleEpilogueProblem
{
    using AsDataType = remove_cvref_t;
    using BsDataType = remove_cvref_t;
    using AccDataType = remove_cvref_t;
    using ODataType = remove_cvref_t;
    using DsDataType = remove_cvref_t;
    using DsLayout = remove_cvref_t;
    using ELayout = remove_cvref_t;
    using CDElementwise = remove_cvref_t;
    // Threads per workgroup: one warp for each of the MWave x NWave wave positions.
    static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
    static constexpr index_t kMPerBlock = kM_;
    static constexpr index_t kNPerBlock = kN_;
    static constexpr index_t MWave = MWave_;
    static constexpr index_t NWave = NWave_;
    static constexpr index_t MPerXdl = MPerXdl_;
    static constexpr index_t NPerXdl = NPerXdl_;
    static constexpr index_t KPerXdl = KPerXdl_;
    static constexpr index_t isCTransposed = isCTransposed_;
    static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
    static constexpr bool FixedVectorSize = FixedVectorSize_;
    static constexpr index_t VectorSizeC = VectorSizeC_;
    static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
    static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
    // Number of auxiliary D tensors; must equal the number of D layouts.
    static constexpr index_t NumDTensor = DsDataType::size();

    static_assert(NumDTensor == DsLayout::size(),
                  "The size of DsDataType and DsLayout should be the same");
};

/**
 * @brief Block-level GEMM epilogue that writes the accumulator tile to the output (E) tensor.
 *
 * Two `operator()` overloads exist; which one participates is selected by an enable_if-style
 * template argument whose condition is not visible in this copy (presumably TiledMMAPermuteN):
 *  - a register-only path that ignores `p_smem` and writes each MRepeat slice of the
 *    accumulator directly through a permuted per-thread distribution;
 *  - a "C-shuffle" path that round-trips each sub-tile through LDS (`p_smem`) to
 *    re-distribute it before the DRAM store.
 * Both paths can multiply by optional row (M) / column (N) scales and then store or update the
 * output according to MemoryOperation. Only the LDS path loads the D tensors and applies
 * CDElementwise.
 */
template
struct CShuffleEpilogue
{
    using Problem = remove_cvref_t;
    using AsDataType = remove_cvref_t;
    using BsDataType = remove_cvref_t;
    using AccDataType = remove_cvref_t;
    using ODataType = remove_cvref_t;
    using DsDataType = remove_cvref_t;
    using DsLayout = remove_cvref_t;

    // A/B element types may be given as a single type or as a tuple of types (presumably for
    // multi-A/B GEMMs). The *IsTuple flags detect the tuple case so that a single type can be
    // wrapped into a tuple and ADataType/BDataType extracted uniformly.
    // NOTE(review): the is_detected / std::conditional_t arguments are not visible here.
    static constexpr bool ADataTypeIsTuple = is_detected::value;
    static constexpr bool BDataTypeIsTuple = is_detected::value;

    using AsDataTypeTuple = std::conditional_t,
                                               remove_cvref_t>>;

    using BsDataTypeTuple = std::conditional_t,
                                               remove_cvref_t>>;

    // Representative A/B element types (presumably the first element of the tuples above).
    using ADataType = remove_cvref_t{}, AsDataTypeTuple>>;
    using BDataType = remove_cvref_t{}, BsDataTypeTuple>>;

    // Element types actually used for A and B; under a condition not visible in this copy one
    // side adopts the other side's type.
    using ATypeToUse =
        std::conditional_t, BDataType, ADataType>;
    // Used for weight-only quantization kernel, B would be dequantized to the same data type as A
    using BTypeToUse =
        std::conditional_t, ADataType, BDataType>;

    using ELayout = remove_cvref_t;
    using CDElementwise = remove_cvref_t;
    static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
    static constexpr index_t kBlockSize = Problem::kBlockSize;
    static constexpr index_t kMPerBlock = Problem::kMPerBlock;
    static constexpr index_t kNPerBlock = Problem::kNPerBlock;
    static constexpr index_t MWave = Problem::MWave;
    static constexpr index_t NWave = Problem::NWave;
    static constexpr index_t MPerXdl = Problem::MPerXdl;
    static constexpr index_t NPerXdl = Problem::NPerXdl;
    static constexpr index_t KPerXdl = Problem::KPerXdl;
    static constexpr index_t isCTransposed = Problem::isCTransposed;
    static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
    static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
    static constexpr index_t VectorSizeC = Problem::VectorSizeC;
    // M/N extent covered by one XDL tile per wave across all waves of the block.
    static constexpr index_t MPerIteration = MPerXdl * MWave;
    static constexpr index_t NPerIteration = NPerXdl * NWave;
    static constexpr index_t NumDTensor = Problem::NumDTensor;
    // Number of MPerIteration / NPerIteration steps needed to span kMPerBlock / kNPerBlock.
    static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
    static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);

    static_assert(NumDTensor == DsLayout::size(),
                  "The size of DsDataType and DsLayout should be the same");
    /**
     * @brief Get the vector store size for C tensor.
     *
     * @note In principle the vector store size for the output C tensor depends on several
     *       factors such as its data layout and warp GEMM C transposition; in general it is
     *       the number of consecutive elements in the contiguous C dimension held by a single
     *       thread. This implementation looks only at ELayout and sizeof(ODataType).
     *
     * @details If FixedVectorSize is set, VectorSizeC is returned unchanged. Otherwise the
     *          result is the number of ODataType elements that fit in 16 bytes, capped by
     *          NPerIteration or MPerIteration depending on which ELayout branch is taken.
     *
     * @return The vector store size for C tensor.
     */
    CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
    {
        if constexpr(FixedVectorSize)
        {
            return VectorSizeC;
        }
        // Maximum vector access width in bytes (divided by the element size below).
        constexpr index_t max_vector_size = 16;
        if constexpr(std::is_same_v)
        {
            return std::min(static_cast(NPerIteration),
                            static_cast(max_vector_size / sizeof(ODataType)));
        }
        else if constexpr(std::is_same_v)
        {
            return std::min(static_cast(MPerIteration),
                            static_cast(max_vector_size / sizeof(ODataType)));
        }
        else
        {
            // NOTE(review): a non-dependent static_assert(false) in a discarded if-constexpr
            // branch is only guaranteed well-formed since P2593 (C++23; recent Clang applies
            // it retroactively). Older compilers may reject it.
            static_assert(false, "Unsupported ELayout!");
        }
    }

    /**
     * @brief Get the vector store size for Di tensor.
     *
     * Same rule as GetVectorSizeC() (16 bytes worth of elements, capped by one iteration's
     * extent) but using DiDataType/DiLayout, presumably the index-th entries of
     * DsDataType/DsLayout. FixedVectorSize is not consulted here.
     *
     * @param index Compile-time index of the D tensor.
     * @return The vector store size for Di tensor.
     */
    template
    CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number index)
    {
        constexpr index_t max_vector_size = 16;
        using DiDataType = remove_cvref_t>;
        using DiLayout = remove_cvref_t>;
        if constexpr(std::is_same_v)
        {
            return std::min(static_cast(NPerIteration),
                            static_cast(max_vector_size / sizeof(DiDataType)));
        }
        else if constexpr(std::is_same_v)
        {
            return std::min(static_cast(MPerIteration),
                            static_cast(max_vector_size / sizeof(DiDataType)));
        }
        else
        {
            static_assert(false, "Unsupported DLayout!");
        }
        // NOTE(review): unreachable — every branch above returns or fails to compile.
        return max_vector_size / sizeof(DiDataType);
    }
    /**
     * @brief Shuffle tile configuration parameters
     *
     * @details These parameters control the number of XDL tiles processed per wave in each shuffle
     * iteration:
     * - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
     * - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
     *
     * elem_per_thread is the number of C elements per lane for a single XDL tile. If that
     * already reaches the vector store size, one XDL tile per shuffle suffices; otherwise
     * GetVectorSizeC() / elem_per_thread tiles are grouped along one dimension, capped by the
     * number of XDL tiles that dimension has per wave.
     */
    static constexpr auto shuffle_tile_tuple = [] {
        constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
        if constexpr(elem_per_thread >= GetVectorSizeC())
        {
            return std::make_tuple(1, 1);
        }
        else
        {
            constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
            if constexpr(std::is_same_v)
            {
                static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
                                  (kMPerBlock % num_xdl_shuffles == 0),
                              "kMPerBlock must be divisible by MPerXdl*MWave and "
                              "num_xdl_shuffles for CShuffleEpilogue");
                return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
            }
            else
            {
                static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
                                  (kNPerBlock % num_xdl_shuffles == 0),
                              "kNPerBlock must be divisible by NPerXdl*NWave and "
                              "num_xdl_shuffles for CShuffleEpilogue");
                return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
            }
        }
    }();
    static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
    static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);

    // M/N extent of the sub-tile processed per shuffle access. Falls back to a single XDL tile
    // per wave when the grouped extent does not evenly divide the block tile.
    static constexpr auto MNPerIterationShuffle = [] {
        constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
        constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
        if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
            return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
        else
            return std::make_tuple(m_val, n_val);
    }();
    static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
    static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);

    // Warp-level GEMM; within this file it is only used to obtain the per-warp C
    // distribution / tensor types.
    using WG = WarpGemmDispatcher;

    using CWarpDstr = typename WG::CWarpDstr;
    using CWarpTensor = typename WG::CWarpTensor;
    using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
    // Space-filling curve enumerating the shuffle sub-tiles of the block. In the LDS path,
    // get_index() locates the current sub-tile and get_forward_step() gives the offset used to
    // move the DRAM / scale windows to the next access.
    using SFC = space_filling_curve,
                                    sequence<0, 1>,
                                    sequence>;

    /**
     * @brief Naive 2-D LDS descriptor for one shuffle sub-tile.
     *
     * Either N is the contiguous dimension (unit stride on the second index) or M is
     * (unit stride on the first index), depending on the layout checked. NOTE(review): the
     * lengths/strides are not visible in this copy; presumably MPerIterationShuffle x
     * NPerIterationShuffle, matching GetSmemSize().
     */
    template
    CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
    {
        // N is contiguous dimension
        if constexpr(std::is_same_v)
        {
            return make_naive_tensor_descriptor(
                make_tuple(number{}, number{}),
                make_tuple(number{}, number<1>{}));
        }
        // M is contiguous dimension
        else if constexpr(std::is_same_v)
        {
            return make_naive_tensor_descriptor(
                make_tuple(number{}, number{}),
                make_tuple(number<1>{}, number{}));
        }
        else
        {
            static_assert(false, "Unsupported ELayout!");
        }
    }

    /**
     * @brief Register distribution of one shuffle sub-tile used by the LDS path.
     *
     * Builds an outer (block-level) encoding and embeds the warp GEMM's C distribution
     * (CWarpDstr::DstrEncode) into it. NOTE(review): the outer encoding's sequence arguments
     * are not visible in this copy.
     */
    CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
    {
        constexpr auto block_outer_dstr_encoding =
            tile_distribution_encoding,
                                       tuple,
                                             sequence>,
                                       tuple>,
                                       tuple>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};
        constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
            block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});

        return block_dstr_encoding;
    }

    /// @brief LDS bytes needed by the LDS path: one shuffle sub-tile of ODataType.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
    }

    /**
     * @brief Apply optional row (M) and column (N) scales to one sub-tile in place.
     *
     * Three compile-time cases (the compared types are not visible in this copy):
     *  - no scales (EmptyScale): nothing to do;
     *  - scalar AccDataType scales: every element is multiplied by scale_m * scale_n;
     *  - otherwise the scales are tile windows: they are loaded, multiplied elementwise into
     *    lds_tile, and advanced to the next SFC access unless this is the last one.
     *
     * @tparam iAccess       Index of the current space-filling-curve access.
     * @param lds_tile        Register tile scaled in place.
     * @param scale_m_window  Row scale (EmptyScale, scalar, or tile window); may be advanced.
     * @param scale_n_window  Column scale (EmptyScale, scalar, or tile window); may be advanced.
     */
    template
    CK_TILE_DEVICE void
    scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
    {
        // Check if scales are EmptyScale first (no scaling needed)
        if constexpr(std::is_same_v && std::is_same_v)
        {
            // No scaling needed - this is a no-op
        }
        // Check if scales are scalar AccDataType
        else if constexpr(std::is_same_v && std::is_same_v)
        {
            // Handle scalar scales
            const AccDataType scale_m = scale_m_window;
            const AccDataType scale_n = scale_n_window;
            tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; },
                                   lds_tile);
        }
        // Otherwise, assume they are tile windows that can be loaded
        else
        {
            // Load tiles
            const auto scale_m_tile = load_tile(scale_m_window);
            const auto scale_n_tile = load_tile(scale_n_window);

            // Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
            tile_elementwise_inout(
                element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);

            // Move scale windows to the next access (skipped after the last access).
            constexpr index_t num_access = SFC::get_num_of_access();
            if constexpr(iAccess != num_access - 1)
            {
                constexpr auto step = SFC::get_forward_step(iAccess);

                move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
                move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
            }
        }
    }

    /**
     * @brief Copy the accumulator data for SFC access iAccess into lds_tile's thread buffer.
     *
     * The access's starting index (SFC::get_index) is divided by MPerIterationShuffle /
     * NPerIterationShuffle to get mIter / nIter. The matching Y-slice of the accumulator,
     * spanning the full per-warp C Y lengths, is then extracted.
     */
    template
    CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile)
    {
        constexpr auto idx_y_start = SFC::get_index(iAccess);

        constexpr auto mIter = number{}) / (MPerIterationShuffle)>{};
        constexpr auto nIter = number{}) / (NPerIterationShuffle)>{};
        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{};

        lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
            merge_sequences(
                sequence{},
                c_warp_y_index_zeros),
            merge_sequences(sequence{},
                            c_warp_y_lengths));
    }

    /// @brief Convert lds_tile to the output element type (presumably ODataType; the cast's
    ///        template argument is not visible here) and store it through the LDS window.
    template
    CK_TILE_DEVICE void cast_lds_tile(LdsTile& lds_tile, InLdsWindow& in_lds_window)
    {
        const auto c_warptile_in_tensor_casted = cast_tile(lds_tile);

        store_tile(in_lds_window, c_warptile_in_tensor_casted);
    }

    /**
     * @brief Load one tile from every D window and apply CDElementwise in place.
     *
     * The functor is invoked via tile_elementwise_inout_unpack with the argument list
     * (c_out_tensor, c_out_tensor, d_0, d_1, ...). c_out_tensor is therefore both the
     * destination and the first input.
     */
    template
    CK_TILE_DEVICE void apply_d_tensors(DramWindows& d_dram_windows, COutTensor& c_out_tensor)
    {
        const auto ds_tensor = generate_tuple(
            [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number{});

        const auto c_ds_tiles = concat_tuple_of_reference(
            tie(c_out_tensor, c_out_tensor),
            generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
                         number{}));

        tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
    }

    /// @brief Write c_out_tensor to DRAM. Uses store_tile when MemoryOperation is `set`;
    ///        otherwise uses update_tile, which presumably honours the other memory operation
    ///        (e.g. accumulate) — confirm against update_tile.
    template
    CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window,
                                      const COutTensor& c_out_tensor)
    {
        if constexpr(MemoryOperation == memory_operation_enum::set)
        {
            store_tile(out_dram_window, c_out_tensor);
        }
        else
        {
            update_tile(out_dram_window, c_out_tensor);
        }
    }

    /**
     * @brief Move both the output and the D tensor windows to the next SFC access.
     *
     * No-op for the last access (iAccess == number of accesses - 1).
     */
    template
    CK_TILE_DEVICE void move_windows(OutDramWindow& out_dram_window, DDramWindows& d_dram_windows)
    {
        constexpr index_t num_access = SFC::get_num_of_access();
        if constexpr(iAccess != num_access - 1)
        {
            constexpr auto step = SFC::get_forward_step(iAccess);

            // move the output dram window
            move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});

            // move windows for each of the D matrices (inputs for element-wise)
            static_for<0, NumDTensor, 1>{}([&](auto idx) {
                move_tile_window(d_dram_windows[idx], {step.at(number<0>{}), step.at(number<1>{})});
            });
        }
    }

    // Sentinel scale type meaning "no scale supplied"; compared against the scale types to
    // disable scaling.
    // TODO: Check if there would be nicer ways to overload rather than with EmptyScale or nullptr_t
    struct EmptyScale
    {
    };

    /**
     * @brief Register-only epilogue path (no LDS; the p_smem argument is ignored).
     *
     * For each of the MRepeat row slices of the accumulator:
     *  1. slice the accumulator for that slice (all NRepeat column repeats);
     *  2. reorder it into c_out_tensor, whose distribution gives each lane kM2 = 4 rows (see
     *     `emit`), fusing optional scalar or tile scales and converting to ODataType;
     *  3. store/update the slice through out_dram_window, then advance the output and D
     *     windows by one MPerXdl * MWave chunk along M.
     *
     * @param out_dram_window  Output tile window; advanced by this call.
     * @param o_acc_tile       Block accumulator tile.
     * @param ds_dram_windows  D tensor windows (see NOTE below).
     * @param scale_m, scale_n Optional row/column scales: default-constructed when omitted
     *                         (presumably EmptyScale), scalar AccDataType, or loadable tiles.
     *
     * NOTE(review): in this path the D windows are created and advanced but never loaded, so
     * CDElementwise is not applied. Also, scale_m_window / scale_n_window are not advanced
     * between MRepeat iterations while out_dram_window is; with tile scales and MRepeat > 1,
     * every slice appears to reuse the first slice's scales. Confirm whether this is intended.
     */
    template = 0>
    CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
                                   const OAccTile& o_acc_tile,
                                   const DsDramWindows& ds_dram_windows,
                                   void* /*p_smem*/,
                                   const ScaleM& scale_m = {},
                                   const ScaleN& scale_n = {})
    {
        constexpr int kM0 = MWave;
        constexpr int kM2 = 4; // rows per lane; must match the four emit() calls below
        constexpr int kM1 = MPerXdl / kM2;

        constexpr int kN0 = NWave;
        constexpr int kN1 = NPerXdl;
        constexpr int kN2 = NRepeat;

        using IntrThreadShuffleEncode =
            tile_distribution_encoding,
                                       tuple, sequence>,
                                       tuple, sequence<1, 2>>,
                                       tuple, sequence<1, 1>>,
                                       sequence<1, 2>,
                                       sequence<2, 2>>;
        constexpr auto dram_tile_distribution =
            make_static_tile_distribution(IntrThreadShuffleEncode{});

        auto d_dram_windows = generate_tuple(
            [&](auto idx) { return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); },
            number{});

        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{};

        auto shuffle_acc = make_static_distributed_tensor(dram_tile_distribution);
        auto c_out_tensor = make_static_distributed_tensor(dram_tile_distribution);

        // Optional scales (must share the same distribution to match per-thread indexing)
        constexpr bool has_scales =
            !std::is_same::value && !std::is_same::value;
        constexpr bool has_scalar_scales =
            std::is_same_v && std::is_same_v;

        // Tiles to hold row/col scales when present
        using SMType = typename GetDataType>::type;
        using SNType = typename GetDataType>::type;

        auto sm_tile = make_static_distributed_tensor(dram_tile_distribution);
        auto sn_tile = make_static_distributed_tensor(dram_tile_distribution);

        // Build windows only for tile scales; scalar scales are read directly in emit().
        auto scale_m_window = [&]() {
            if constexpr(has_scales && !has_scalar_scales)
            {
                static_assert(
                    IsLoadableTile,
                    "ScaleM must be a loadable tile");
                return make_tile_window(scale_m, dram_tile_distribution);
            }
            else
            {
                return EmptyScale{};
            }
        }();
        auto scale_n_window = [&]() {
            if constexpr(has_scales && !has_scalar_scales)
            {
                static_assert(
                    IsLoadableTile,
                    "ScaleN must be a loadable tile");
                return make_tile_window(scale_n, dram_tile_distribution);
            }
            else
            {
                return EmptyScale{};
            }
        }();

        static_for<0, MRepeat, 1>{}([&](auto mIter) {
            // Slice accumulators for this M repeat into the permuted layout
            shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
                merge_sequences(sequence{}, c_warp_y_index_zeros),
                merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));

            // If scales provided, load them with identical distribution
            if constexpr(has_scales && IsLoadableTile && IsLoadableTile)
            {
                sm_tile = load_tile(scale_m_window); // row scales in permuted layout
                sn_tile = load_tile(scale_n_window); // col scales in permuted layout
            }

            // Pack kM2 = 4 rows per lane from shuffle_acc into c_out_tensor.
            static_for<0, NRepeat, 1>{}([&](auto n_idx) {
                // source indices in shuffle_acc: (n_idx * product(Y) + row)
                const index_t base = n_idx * c_warp_y_lengths.product();

                // local lambda to fuse scale (if present) and convert
                auto emit = [&](index_t out_idx, index_t src_row) {
                    AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];

                    if constexpr(has_scalar_scales)
                    {
                        v = static_cast(v * scale_m * scale_n);
                    }
                    else if constexpr(has_scales)
                    {
                        // same linear index mapping on the permuted distribution
                        const auto s_m = static_cast(sm_tile.get_thread_buffer()[out_idx]);
                        const auto s_n = static_cast(sn_tile.get_thread_buffer()[out_idx]);
                        v = static_cast(v * s_m * s_n);
                    }

                    c_out_tensor.get_thread_buffer()[out_idx] = type_convert(v);
                };

                // Source rows 0..3 of this n_idx go to output slots spaced NRepeat apart.
                // NOTE(review): appears to assume each warp Y-slice holds exactly kM2 = 4
                // values (c_warp_y_lengths.product() == 4) — confirm.
                emit(n_idx + 0 * NRepeat, 0);
                emit(n_idx + 1 * NRepeat, 1);
                emit(n_idx + 2 * NRepeat, 2);
                emit(n_idx + 3 * NRepeat, 3);
            });

            // store/update
            if constexpr(MemoryOperation == memory_operation_enum::set)
            {
                store_tile(out_dram_window, c_out_tensor);
            }
            else
            {
                update_tile(out_dram_window, c_out_tensor);
            }

            // advance output (and any D-tensors) by one MPerXdl*MWave chunk
            move_tile_window(out_dram_window, {number{}, number<0>{}});
            static_for<0, NumDTensor, 1>{}([&](auto idx) {
                move_tile_window(d_dram_windows[idx], {number{}, number<0>{}});
            });
        });
    }

    /**
     * @brief C-shuffle epilogue path through LDS.
     *
     * For each space-filling-curve access:
     *  1. synchronize on LDS;
     *  2. slice the accumulator sub-tile and apply the optional scales;
     *  3. cast it and write it to LDS, then synchronize again;
     *  4. read it back with a DRAM-oriented 2-D distribution;
     *  5. combine it with the D tensors via CDElementwise and store/update it to DRAM;
     *  6. move the output and D windows to the next access.
     * Only row-major output is supported, as enforced by the static_assert below.
     *
     * @param out_dram_window  Output tile window; advanced by this call.
     * @param o_acc_tile       Block accumulator tile.
     * @param ds_dram_windows  D tensor windows combined via CDElementwise.
     * @param p_smem           LDS scratch used for the shuffle; presumably at least
     *                         GetSmemSize() bytes.
     * @param scale_m, scale_n Optional row/column scales: EmptyScale, scalar AccDataType, or
     *                         tile windows (re-distributed like lds_tile).
     */
    template = 0>
    CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
                                   const OAccTile& o_acc_tile,
                                   const DsDramWindows& ds_dram_windows,
                                   void* p_smem,
                                   const ScaleM& scale_m = {},
                                   const ScaleN& scale_n = {})
    {
        constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());

        auto lds_tile = make_static_distributed_tensor(LdsTileDistr);

        constexpr auto lds_block_desc = MakeLdsBlockDescriptor();
        auto o_lds_block = make_tensor_view(
            static_cast(p_smem), lds_block_desc);

        // Window used to write the register sub-tile into LDS (carries LdsTileDistr).
        auto in_lds_window = make_tile_window(
            o_lds_block,
            make_tuple(number{}, number{}),
            {0, 0},
            LdsTileDistr);

        // Distribution-less window over the same LDS; re-distributed when read back below.
        auto out_lds_window = make_tile_window(
            o_lds_block,
            make_tuple(number{}, number{}),
            {0, 0});

        constexpr index_t num_access = SFC::get_num_of_access();

        static_assert(std::is_same_v,
                      "Currently, the CShuffle Epilogue only supports the Row Major Output layout");

        using TileEncodingPattern =
            tile_distribution_encoding_pattern_2d;
        constexpr auto dram_tile_distribution =
            TileEncodingPattern::make_2d_static_tile_distribution();

        auto d_dram_windows = generate_tuple(
            [&](auto idx) {
                return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
            },
            number{});

        constexpr bool has_scales =
            !std::is_same_v && !std::is_same_v;
        constexpr bool has_scalar_scales =
            std::is_same_v && std::is_same_v;
        // Scalar scales are passed through as values; tile scales become windows that share
        // lds_tile's distribution so scale_tile() can multiply them elementwise.
        auto scale_m_window = [&]() {
            if constexpr(has_scalar_scales)
            {
                return scale_m;
            }
            else if constexpr(has_scales)
            {
                static_assert(
                    IsLoadableTile,
                    "ScaleM must be a loadable tile");
                return make_tile_window(scale_m, lds_tile.get_tile_distribution());
            }
            else
            {
                return EmptyScale{};
            }
        }();
        auto scale_n_window = [&]() {
            if constexpr(has_scalar_scales)
            {
                return scale_n;
            }
            else if constexpr(has_scales)
            {
                static_assert(
                    IsLoadableTile,
                    "ScaleN must be a loadable tile");
                return make_tile_window(scale_n, lds_tile.get_tile_distribution());
            }
            else
            {
                return EmptyScale{};
            }
        }();

        static_for<0, num_access, 1>{}([&](auto iAccess) {
            // Previous access must finish reading LDS before this one overwrites it.
            block_sync_lds();
            slice_acc_tile(o_acc_tile, lds_tile);

            if constexpr(has_scales)
            {
                scale_tile(lds_tile, scale_m_window, scale_n_window);
            }

            cast_lds_tile(lds_tile, in_lds_window);
            // LDS writes must complete before the sub-tile is read back.
            block_sync_lds();

            auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));

            apply_d_tensors(d_dram_windows, c_out_tensor);
            store_to_dram(out_dram_window, c_out_tensor);
            move_windows(out_dram_window, d_dram_windows);
        });
    }
};
} // namespace ck_tile