// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

// NOTE(review): this copy of the file appears to have lost every angle-bracket list that begins
// with an identifier. For example:
//   - the last #include has no header name;
//   - `template` lines have no parameter list;
//   - `remove_cvref_t;` and `std::is_same_v` have no arguments;
//   - calls such as `slice_acc_tile(...)` pass no explicit template arguments, even though
//     `iAccess` is used inside them without being a function parameter.
// Numeric lists such as `sequence<0, 1>` and `number<1>{}` survived. Restore the missing lists
// from upstream before building. The comments below describe only what the surviving tokens show
// and hedge everything else.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include 

namespace ck_tile {

/**
 * @brief Compile-time configuration consumed by CShuffleEpilogue.
 *
 * @details Re-exports the following as member aliases and constants:
 *          - the A/B/accumulator/output/D data types;
 *          - the D and E layouts;
 *          - the C/D element-wise functor;
 *          - the tiling constants: block tile kM_ x kN_, an MWave_ x NWave_ wave grid, and the
 *            per-XDL M/N/K sizes.
 *          Which template parameter feeds each type alias is not visible in this copy.
 */
template struct CShuffleEpilogueProblem
{
    using ADataType = remove_cvref_t;
    using BDataType = remove_cvref_t;
    using AccDataType = remove_cvref_t;
    using ODataType = remove_cvref_t;
    using DsDataType = remove_cvref_t;
    using DsLayout = remove_cvref_t;
    using ELayout = remove_cvref_t;
    using CDElementwise = remove_cvref_t;
    // Threads per block: MWave_ * NWave_ warps of get_warp_size() lanes each.
    static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
    static constexpr index_t kMPerBlock = kM_;
    static constexpr index_t kNPerBlock = kN_;
    static constexpr index_t MWave = MWave_;
    static constexpr index_t NWave = NWave_;
    static constexpr index_t MPerXdl = MPerXdl_;
    static constexpr index_t NPerXdl = NPerXdl_;
    static constexpr index_t KPerXdl = KPerXdl_;
    // NOTE(review): typed index_t although the name reads like a boolean flag.
    static constexpr index_t isCTransposed = isCTransposed_;
    static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
    // When true, CShuffleEpilogue::GetVectorSizeC() returns VectorSizeC instead of deriving it.
    static constexpr bool FixedVectorSize = FixedVectorSize_;
    static constexpr index_t VectorSizeC = VectorSizeC_;
    static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
    // Number of extra D input tensors; each needs a matching entry in DsLayout.
    static constexpr index_t NumDTensor = DsDataType::size();
    static_assert(NumDTensor == DsLayout::size(),
                  "The size of DsDataType and DsLayout should be the same");
};

/**
 * @brief Block-level GEMM epilogue that writes the accumulator tile to the output (E) tensor by
 *        shuffling it through an LDS staging buffer.
 *
 * @details The block tile is processed as a sequence of accesses enumerated by the space-filling
 *          curve SFC. For each access:
 *          1. the matching slice of the accumulator tile is copied out (slice_acc_tile);
 *          2. it is optionally scaled (scale_tile);
 *          3. it is cast and written to LDS with the warp-GEMM C distribution (cast_lds_tile);
 *          4. it is read back with a 2-D DRAM-oriented distribution;
 *          5. it is combined with the D tensors through CDElementwise (apply_d_tensors);
 *          6. it is stored or updated through the output DRAM window (store_to_dram).
 */
template struct CShuffleEpilogue
{
    using Problem = remove_cvref_t;
    using ADataType = remove_cvref_t;
    using BDataType = remove_cvref_t;
    using AccDataType = remove_cvref_t;
    using ODataType = remove_cvref_t;
    using DsDataType = remove_cvref_t;
    using DsLayout = remove_cvref_t;
    // NOTE(review): the std::conditional_t conditions of ATypeToUse/BTypeToUse are not visible in
    // this copy; each alias picks between ADataType and BDataType.
    using ATypeToUse = std::conditional_t, BDataType, ADataType>;
    // Used for weight-only quantization kernel, B would be dequantized to the same data type as A
    using BTypeToUse = std::conditional_t, ADataType, BDataType>;
    using ELayout = remove_cvref_t;
    using CDElementwise = remove_cvref_t;
    static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
    static constexpr index_t kBlockSize = Problem::kBlockSize;
    static constexpr index_t kMPerBlock = Problem::kMPerBlock;
    static constexpr index_t kNPerBlock = Problem::kNPerBlock;
    static constexpr index_t MWave = Problem::MWave;
    static constexpr index_t NWave = Problem::NWave;
    static constexpr index_t MPerXdl = Problem::MPerXdl;
    static constexpr index_t NPerXdl = Problem::NPerXdl;
    static constexpr index_t KPerXdl = Problem::KPerXdl;
    static constexpr index_t isCTransposed = Problem::isCTransposed;
    static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
    static constexpr index_t VectorSizeC = Problem::VectorSizeC;
    // M / N extent covered by one XDL tile on every wave of the block.
    static constexpr index_t MPerIteration = MPerXdl * MWave;
    static constexpr index_t NPerIteration = NPerXdl * NWave;
    static constexpr index_t NumDTensor = Problem::NumDTensor;
    static_assert(NumDTensor == DsLayout::size(),
                  "The size of DsDataType and DsLayout should be the same");

    /**
     * @brief Get the vector store size for C tensor.
     *
     * @note The vector store size for the output C tensor depends on several factors, such as
     *       its data layout and warp-GEMM C transposition. In general it is the number of
     *       consecutive elements along the contiguous C dimension held by a single thread.
     *
     * @return VectorSizeC when FixedVectorSize is set. Otherwise it is the number of ODataType
     *         elements in 16 bytes, capped by NPerIteration or MPerIteration depending on the
     *         layout.
     */
    CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
    {
        // Not followed by `else`: the layout dispatch below is still instantiated when
        // FixedVectorSize is set.
        if constexpr(FixedVectorSize)
        {
            return VectorSizeC;
        }
        // Widest store considered, in bytes (divided by sizeof(ODataType) below).
        constexpr index_t max_vector_size = 16;
        // The is_same_v operands are not visible in this copy. The first branch caps the width
        // by NPerIteration (presumably the N-contiguous, row-major case). The second caps it by
        // MPerIteration (presumably the M-contiguous, column-major case).
        if constexpr(std::is_same_v)
        {
            return std::min(static_cast(NPerIteration),
                            static_cast(max_vector_size / sizeof(ODataType)));
        }
        else if constexpr(std::is_same_v)
        {
            return std::min(static_cast(MPerIteration),
                            static_cast(max_vector_size / sizeof(ODataType)));
        }
        else
        {
            // NOTE(review): `static_assert(false, ...)` in a discarded `if constexpr` branch is
            // only accepted by compilers implementing P2593 (C++23, applied as a DR by recent
            // GCC/Clang). Older compilers reject it even when this branch is discarded.
            static_assert(false, "Unsupported ELayout!");
        }
    }

    /**
     * @brief Get the vector store size for Di tensor.
     *
     * @details Same rule as GetVectorSizeC(): 16 bytes, capped by NPerIteration or
     *          MPerIteration depending on the layout. It uses the data type and layout of the
     *          D tensor selected by @p index. There is no FixedVectorSize override here. The
     *          lookup of the D type/layout from @p index is not visible in this copy.
     *
     * @param index Compile-time index of the D tensor.
     *
     * @return The vector store size for Di tensor.
     */
    template CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number index)
    {
        // Widest store considered, in bytes (divided by sizeof(DiDataType) below).
        constexpr index_t max_vector_size = 16;
        using DiDataType = remove_cvref_t>;
        using DiLayout = remove_cvref_t>;
        if constexpr(std::is_same_v)
        {
            return std::min(static_cast(NPerIteration),
                            static_cast(max_vector_size / sizeof(DiDataType)));
        }
        else if constexpr(std::is_same_v)
        {
            return std::min(static_cast(MPerIteration),
                            static_cast(max_vector_size / sizeof(DiDataType)));
        }
        else
        {
            static_assert(false, "Unsupported DLayout!");
        }
        // NOTE(review): every branch above either returns or fails to compile, so this return
        // looks unreachable.
        return max_vector_size / sizeof(DiDataType);
    }

    /**
     * @brief Shuffle tile configuration parameters
     *
     * @details These parameters control the number of XDL tiles processed per wave in each
     *          shuffle iteration:
     *          - NumMXdlPerWavePerShuffle: number of XDL tiles in the M dimension processed
     *            per wave
     *          - NumNXdlPerWavePerShuffle: number of XDL tiles in the N dimension processed
     *            per wave
     *          At most one of the two is greater than 1.
     */
    static constexpr auto shuffle_tile_tuple = [] {
        // Accumulator elements of one MPerXdl x NPerXdl XDL tile held by each lane.
        constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
        if constexpr(elem_per_thread >= GetVectorSizeC())
        {
            // A single XDL tile already gives each lane at least GetVectorSizeC() elements.
            return std::make_tuple(1, 1);
        }
        else
        {
            // Group enough XDL tiles to reach GetVectorSizeC() elements per lane. The grouping
            // is capped by how many XDL tiles per wave fit in the block tile along that
            // dimension. It runs along M in the first branch and along N otherwise; the layout
            // tested by is_same_v is not visible in this copy.
            constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
            if constexpr(std::is_same_v)
            {
                static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
                                  (kMPerBlock % num_xdl_shuffles == 0),
                              "kMPerBlock must be divisible by MPerXdl*MWave and "
                              "num_xdl_shuffles for CShuffleEpilogue");
                return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
            }
            else
            {
                static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
                                  (kNPerBlock % num_xdl_shuffles == 0),
                              "kNPerBlock must be divisible by NPerXdl*NWave and "
                              "num_xdl_shuffles for CShuffleEpilogue");
                return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
            }
        }
    }();
    static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
    static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);

    // M / N extent of one shuffle step. If the grouped extent does not evenly divide the block
    // tile in either dimension, both dimensions fall back to a single XDL tile per wave.
    static constexpr auto MNPerIterationShuffle = [] {
        constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
        constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
        if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
            return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
        else
            return std::make_tuple(m_val, n_val);
    }();
    static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
    static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);

    // Warp-level GEMM type. Only its C-warp distribution/tensor types are referenced here.
    using WG = WarpGemmDispatcher;
    using CWarpDstr = typename WG::CWarpDstr;
    using CWarpTensor = typename WG::CWarpTensor;
    using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
    // Enumerates the accesses over the block tile. get_num_of_access(), get_index() and
    // get_forward_step() drive the per-access loop below. The curve's extents are not visible
    // in this copy.
    using SFC = space_filling_curve,
                                    sequence<0, 1>,
                                    sequence>;

    /**
     * @brief Naive 2-D tensor descriptor for the LDS staging buffer.
     *
     * @details The first branch uses strides (x, 1), so N is contiguous. The second uses
     *          strides (1, x), so M is contiguous. The extents, the strides, and the template
     *          parameter that selects the branch are not visible in this copy.
     */
    template CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
    {
        // N is contiguous dimension
        if constexpr(std::is_same_v)
        {
            return make_naive_tensor_descriptor(
                make_tuple(number{}, number{}),
                make_tuple(number{}, number<1>{}));
        }
        // M is contiguous dimension
        else if constexpr(std::is_same_v)
        {
            return make_naive_tensor_descriptor(
                make_tuple(number{}, number{}),
                make_tuple(number<1>{}, number{}));
        }
        else
        {
            static_assert(false, "Unsupported ELayout!");
        }
    }

    /**
     * @brief Tile distribution encoding for one shuffle slice.
     *
     * @details Embeds the warp GEMM's C distribution (CWarpDstr::DstrEncode) inside an outer
     *          block-level encoding. The outer encoding's hidden sequences are not visible in
     *          this copy.
     */
    CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
    {
        constexpr auto block_outer_dstr_encoding =
            tile_distribution_encoding,
                                       tuple,
                                             sequence>,
                                       tuple>,
                                       tuple>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};
        constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
            block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});

        return block_dstr_encoding;
    }

    // Bytes of LDS needed for one shuffle slice of MPerIterationShuffle x NPerIterationShuffle
    // ODataType elements.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
    }

    /**
     * @brief Scale the current slice in place, then advance the scale windows.
     *
     * @details Computes lds_tile = lds_tile * scale_m * scale_n element-wise with
     *          MultiDMultiply. Both scale windows are then moved by the SFC forward step of
     *          access `iAccess`, except after the last access. The template parameter list
     *          (iAccess and the argument types) is not visible in this copy.
     */
    template CK_TILE_DEVICE void
    scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
    {
        // Load tiles
        const auto scale_m_tile = load_tile(scale_m_window);
        const auto scale_n_tile = load_tile(scale_n_window);

        // Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
        tile_elementwise_inout(
            element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);

        // Move scale windows
        constexpr index_t num_access = SFC::get_num_of_access();
        if constexpr(iAccess != num_access - 1)
        {
            constexpr auto step = SFC::get_forward_step(iAccess);

            move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
            move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
        }
    }

    /**
     * @brief Copy the accumulator values for access `iAccess` into lds_tile's thread buffer.
     *
     * @details The SFC index of the access is divided by MPerIterationShuffle and
     *          NPerIterationShuffle to get the shuffle-step coordinates mIter and nIter. The
     *          numerators are presumably the M/N components of idx_y_start (not visible here).
     *          The Y-slice start and length are the warp-level Y zeros and lengths, each with a
     *          prefix sequence that is not visible in this copy.
     */
    template CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile)
    {
        constexpr auto idx_y_start = SFC::get_index(iAccess);

        constexpr auto mIter = number{}) / (MPerIterationShuffle)>{};
        constexpr auto nIter = number{}) / (NPerIterationShuffle)>{};
        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{};

        lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
            merge_sequences(
                sequence{},
                c_warp_y_index_zeros),
            merge_sequences(sequence{},
                            c_warp_y_lengths));
    }

    // Cast lds_tile with cast_tile and store the result into LDS through in_lds_window. The
    // target type is not visible in this copy; presumably it is ODataType.
    template CK_TILE_DEVICE void cast_lds_tile(LdsTile& lds_tile, InLdsWindow& in_lds_window)
    {
        const auto c_warptile_in_tensor_casted = cast_tile(lds_tile);

        store_tile(in_lds_window, c_warptile_in_tensor_casted);
    }

    /**
     * @brief Combine the output slice with the D tensors in place.
     *
     * @details Loads one tile from each D window, then applies Problem::CDElementwise. The
     *          tied tuple starts with c_out_tensor twice, as the output and as the first input,
     *          followed by the D tiles.
     */
    template CK_TILE_DEVICE void apply_d_tensors(DramWindows& d_dram_windows, COutTensor& c_out_tensor)
    {
        const auto ds_tensor = generate_tuple(
            [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number{});

        const auto c_ds_tiles = concat_tuple_of_reference(
            tie(c_out_tensor, c_out_tensor),
            generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
                         number{}));

        tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
    }

    // Write c_out_tensor through the output window. This is a plain store_tile when
    // MemoryOperation is `set`, and update_tile for any other memory operation.
    template CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window, const COutTensor& c_out_tensor)
    {
        if constexpr(MemoryOperation == memory_operation_enum::set)
        {
            store_tile(out_dram_window, c_out_tensor);
        }
        else
        {
            update_tile(out_dram_window, c_out_tensor);
        }
    }

    /**
     * @brief Move both the output and D tensors windows for the next access.
     *
     * @details Uses the SFC forward step of access `iAccess`. Does nothing after the last
     *          access.
     */
    template CK_TILE_DEVICE void move_windows(OutDramWindow& out_dram_window, DDramWindows& d_dram_windows)
    {
        constexpr index_t num_access = SFC::get_num_of_access();
        if constexpr(iAccess != num_access - 1)
        {
            constexpr auto step = SFC::get_forward_step(iAccess);

            // move the output dram window
            move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});

            // move windows for each of the D matrices (inputs for element-wise)
            static_for<0, NumDTensor, 1>{}([&](auto idx) {
                move_tile_window(d_dram_windows[idx], {step.at(number<0>{}), step.at(number<1>{})});
            });
        }
    }

    // Placeholder returned instead of a scale window when operator() has no scales.
    // TODO: Check if there would be nicer ways to overload rather than with EmptyScale or nullptr_t
    struct EmptyScale
    {
    };

    /**
     * @brief Run the epilogue for one block tile.
     *
     * @param out_dram_window Output window; moved by the SFC forward step after every access
     *                        except the last.
     * @param o_acc_tile      Block accumulator tile, consumed one access slice at a time.
     * @param ds_dram_windows Windows over the D tensors; re-wrapped with the DRAM tile
     *                        distribution before use.
     * @param p_smem          LDS staging buffer. It presumably must hold at least
     *                        GetSmemSize() bytes.
     * @param scale_m         Optional scale view; `{}` when omitted.
     * @param scale_n         Optional scale view; `{}` when omitted.
     */
    template CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
                                           const OAccTile& o_acc_tile,
                                           const DsDramWindows& ds_dram_windows,
                                           void* p_smem,
                                           const ScaleM& scale_m = {},
                                           const ScaleN& scale_n = {})
    {
        constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());

        auto lds_tile = make_static_distributed_tensor(LdsTileDistr);

        constexpr auto lds_block_desc = MakeLdsBlockDescriptor();
        auto o_lds_block = make_tensor_view(
            static_cast(p_smem), lds_block_desc);

        // Both windows view the same LDS buffer. in_lds_window writes with the warp-GEMM C
        // distribution. out_lds_window is re-wrapped with the DRAM distribution for reading.
        auto in_lds_window = make_tile_window(
            o_lds_block,
            make_tuple(number{}, number{}),
            {0, 0},
            LdsTileDistr);

        auto out_lds_window = make_tile_window(
            o_lds_block,
            make_tuple(number{}, number{}),
            {0, 0});

        constexpr index_t num_access = SFC::get_num_of_access();

        static_assert(std::is_same_v,
                      "Currently, the CShuffle Epilogue only supports the Row Major Output layout");

        using TileEncodingPattern = tile_distribution_encoding_pattern_2d;
        constexpr auto dram_tile_distribution =
            TileEncodingPattern::make_2d_static_tile_distribution();

        auto d_dram_windows = generate_tuple(
            [&](auto idx) {
                return make_tile_window(ds_dram_windows[idx],
                                        dram_tile_distribution);
            },
            number{});

        // True only when both is_same checks fail. The operands are not visible in this copy;
        // presumably they compare ScaleM / ScaleN against EmptyScale.
        constexpr bool has_scales =
            !std::is_same::value && !std::is_same::value;

        // Scale windows share lds_tile's distribution so they can be multiplied element-wise.
        auto scale_m_window = [&]() {
            if constexpr(has_scales)
            {
                return make_tile_window(scale_m, lds_tile.get_tile_distribution());
            }
            else
            {
                return EmptyScale{};
            }
        }();
        auto scale_n_window = [&]() {
            if constexpr(has_scales)
            {
                return make_tile_window(scale_n, lds_tile.get_tile_distribution());
            }
            else
            {
                return EmptyScale{};
            }
        }();

        static_for<0, num_access, 1>{}([&](auto iAccess) {
            // Barrier before overwriting LDS: the previous access's read-back must be finished.
            block_sync_lds();
            slice_acc_tile(o_acc_tile, lds_tile);

            if constexpr(has_scales)
            {
                scale_tile(lds_tile, scale_m_window, scale_n_window);
            }

            cast_lds_tile(lds_tile, in_lds_window);
            // Barrier after writing LDS, so the whole slice is written before it is read back
            // with a different distribution.
            block_sync_lds();

            auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));

            apply_d_tensors(d_dram_windows, c_out_tensor);
            store_to_dram(out_dram_window, c_out_tensor);
            move_windows(out_dram_window, d_dram_windows);
        });
    }
};
} // namespace ck_tile