Files
composable_kernel/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
msaffari-amd 5348b577ed [rocm-libraries] ROCm/rocm-libraries#5863 (commit 31d9247)
[CK_TILE] Separate PermuteN epilogue from CShuffle epilogue
 into standalone file (#5863)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

The PermuteN epilogue was previously embedded within
cshuffle_epilogue.hpp, despite having fundamentally different behaviour.
Coupling these two independent strategies in one file introduced
unnecessary complexity, SFINAE guards, and a dual operator() overload
selected at compile time via TiledMMAPermuteN_ template parameter.

This PR separates PermuteN into its own standalone
file (permuten_epilogue.hpp), simplifying both implementations and
making the codebase easier to maintain and extend independently.

## Technical Details

**New file: permuten_epilogue.hpp:**
contains PermuteNEpilogueProblem and PermuteNEpilogue, extracted from
the permuteN code path in cshuffle_epilogue.hpp.

**Cleanup of cshuffle_epilogue.hpp:**

- Removed the TiledMMAPermuteN_ template parameter from
[CShuffleEpilogueProblem]
- Removed the SFINAE-guarded permuteN operator() overload
- Removed the EnablePermuateN_ SFINAE alias
- CShuffle now only contains CShuffle logic; EightWave support
(independent feature) is retained

**Consumer migration:**
All consumer files now use compile-time epilogue selection via
[std::conditional_t]

`using GemmEpilogue = std::conditional_t<
    TiledMMAPermuteN,
    PermuteNEpilogue<PermuteNEpilogueProblem<...>>,
    CShuffleEpilogue<CShuffleEpilogueProblem<...>>>;`

**Files modified:**

- flatmm_basic.cpp, moe_flatmm.cpp, a16w4_moe_flatmm.cpp,
mixed_prec_flatmm.cpp, mx_flatmm_instance.hpp — flatmm examples
- run_gemm_quant_example.inc — block-scale GEMM example
- gemm_weight_preshuffle_invoker.hpp — weight preshuffle invoker
- test_gemm_quant_fixtures.hpp, test_gemm_persistent_async_input.cpp,
test_gemm_pipeline_util.hpp — test utilities
- universal_gemm_invoker.hpp — universal GEMM invoker
- epilogue.hpp — aggregate header updated to include permuten_epilogue.hpp

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-14 20:23:26 +00:00

762 lines
34 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/host/concat.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include <type_traits>
namespace ck_tile {
template <typename AsDataType_,
typename BsDataType_,
typename DsDataType_,
typename AccDataType_,
typename ODataType_,
typename DsLayout_,
typename ELayout_,
typename CDElementwise_,
index_t kM_,
index_t kN_,
index_t MWave_,
index_t NWave_,
index_t MPerXdl_,
index_t NPerXdl_,
index_t KPerXdl_,
bool isCTransposed_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp
bool DoubleSmemBuffer_ = false>
struct CShuffleEpilogueProblem
{
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;
static constexpr index_t NWave = NWave_;
static constexpr index_t MPerXdl = MPerXdl_;
static constexpr index_t NPerXdl = NPerXdl_;
static constexpr index_t KPerXdl = KPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
};
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<AsDataType>,
remove_cvref_t<tuple<AsDataType>>>;
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<BsDataType>,
remove_cvref_t<tuple<BsDataType>>>;
// ADataTypeCompute: compute type from Problem (may be tf32_t for TF32 mode)
using ADataTypeCompute = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using BDataTypeCompute = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
// ADataTypeBuf: buffer/storage type (fp32 when tf32)
using ADataTypeBuf = if_select_t<ADataTypeCompute, tf32_t, float, ADataTypeCompute>;
using BDataTypeBuf = if_select_t<BDataTypeCompute, tf32_t, float, BDataTypeCompute>;
// For warp gemm selection: use tf32_t if compute type was tf32_t
// For pk_int4/pk_fp4: use the other data type
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataTypeCompute, tf32_t>,
tf32_t,
std::conditional_t<std::is_same_v<ADataTypeBuf, pk_int4_t> ||
std::is_same_v<ADataTypeBuf, pk_fp4_t>,
BDataTypeBuf,
ADataTypeBuf>>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataTypeCompute, tf32_t>,
tf32_t,
std::conditional_t<std::is_same_v<BDataTypeBuf, pk_int4_t> ||
std::is_same_v<BDataTypeBuf, pk_fp4_t> ||
sizeof(BDataTypeBuf) < sizeof(ADataTypeBuf),
ADataTypeBuf,
BDataTypeBuf>>;
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t MWave = Problem::MWave;
static constexpr index_t NWave = Problem::NWave;
static constexpr index_t MPerXdl = Problem::MPerXdl;
static constexpr index_t NPerXdl = Problem::NPerXdl;
static constexpr index_t KPerXdl = Problem::KPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
#if defined(__gfx9__)
static constexpr bool EightWave = (MWave * NWave == 8);
#else
static constexpr bool EightWave = false;
#endif
static constexpr index_t BlockedXDLN_PerWarp =
EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
CDElementwise elfunc_;
CK_TILE_DEVICE CShuffleEpilogue(CDElementwise elfunc = CDElementwise{}) : elfunc_(elfunc) {};
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "CShuffleEpilogue",
concat('x', MWave, NWave),
concat('x', MPerXdl, NPerXdl, KPerXdl),
VectorSizeC,
isCTransposed ? "CTransposed" : "CNotTransposed");
// clang-format on
}
/**
* @brief Get the vector store size for C tensor.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
{
if constexpr(FixedVectorSize)
{
return VectorSizeC;
}
constexpr index_t max_vector_size = 16;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}
/**
* @brief Get the vector store size for Di tensor.
*
* @return The vector store size for Di tensor.
*/
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
{
constexpr index_t max_vector_size = 16;
using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else
{
static_assert(false, "Unsupported DLayout!");
}
return max_vector_size / sizeof(DiDataType);
}
/**
* @brief Shuffle tile configuration parameters check and aligment
*
* @details Return tuple(1, 1) if shuffle_tile values are too large for SMEM.
*/
template <index_t m_shuffle_tile, index_t n_shuffle_tile>
CK_TILE_HOST_DEVICE static constexpr auto AlignShuffleTileWithSmem()
{
constexpr index_t m_val = MPerXdl * MWave * m_shuffle_tile;
constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile;
constexpr auto shuffle_tile =
m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer
? std::make_tuple(1, 1)
: std::make_tuple(m_shuffle_tile, n_shuffle_tile);
return shuffle_tile;
}
/**
* @brief Shuffle tile configuration parameters
*
* @details These parameters control the number of XDL tiles processed per wave in each shuffle
* iteration:
* - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
* - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
*/
static constexpr auto shuffle_tile_tuple = [] {
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
if constexpr(elem_per_thread <= GetVectorSizeC())
{
return std::make_tuple(1, 1);
}
else
{
constexpr index_t num_xdl_shuffles = elem_per_thread / GetVectorSizeC();
static_assert(elem_per_thread % GetVectorSizeC() == 0);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
(kMPerBlock % num_xdl_shuffles == 0),
"kMPerBlock must be divisible by MPerXdl*MWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return AlignShuffleTileWithSmem<min(num_xdl_shuffles,
kMPerBlock / (MPerXdl * MWave)),
1>();
}
else
{
static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
(kNPerBlock % num_xdl_shuffles == 0),
"kNPerBlock must be divisible by NPerXdl*NWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return AlignShuffleTileWithSmem<1,
min(num_xdl_shuffles,
kNPerBlock / (NPerXdl * NWave))>();
}
}
}();
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
static constexpr index_t NumNXdlPerWavePerShuffle =
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
static constexpr auto MNPerIterationShuffle = [] {
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
else
return std::make_tuple(m_val, n_val);
}();
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
using WG = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
AccDataType,
MPerXdl,
NPerXdl,
KPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
constexpr auto DataTypeSize = sizeof(ODataType);
constexpr index_t VectorLen = GetVectorSizeC();
constexpr index_t banks = get_n_lds_banks();
constexpr index_t BytesPerBank = 4;
// N is contiguous dimension
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
constexpr index_t MLdsLayerRequired =
banks * BytesPerBank / NPerIterationShuffle / DataTypeSize;
constexpr auto MLdsLayer = max(1, MLdsLayerRequired);
constexpr index_t BaseStrideElems = NPerIterationShuffle * MLdsLayer;
static_assert((BaseStrideElems * DataTypeSize) % BytesPerBank == 0,
"LDS row stride must be 4B-aligned for bank-word padding logic");
// calculate how many elements to pad to avoid bank conflict
#if defined(__gfx950__)
constexpr index_t ElemsPer4B = BytesPerBank / ck_tile::gcd(BytesPerBank, DataTypeSize);
constexpr auto ToWords = [](index_t elems) constexpr {
return (elems * DataTypeSize) / BytesPerBank;
};
constexpr index_t BaseWords = ToWords(BaseStrideElems);
constexpr index_t PadWords = ((BaseWords % 2) == 0) ? 1 : 0;
constexpr auto PaddingAmount = PadWords * ElemsPer4B;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle / MLdsLayer>{},
number<NPerIterationShuffle / VectorLen * MLdsLayer>{},
number<VectorLen>{}),
make_tuple(number<NPerIterationShuffle * MLdsLayer + PaddingAmount>{},
number<VectorLen>{},
number<1>{}),
number<VectorLen>{},
number<1>{});
constexpr auto lds_block_desc_1 = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<MPerIterationShuffle / MLdsLayer>{}),
make_unmerge_transform(make_tuple(
number<MLdsLayer>{}, number<NPerIterationShuffle / VectorLen>{})),
make_pass_through_transform(number<VectorLen>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
constexpr auto lds_block_desc = transform_tensor_descriptor(
lds_block_desc_1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(
number<MPerIterationShuffle / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform_v3_division_mod(make_tuple(
number<NPerIterationShuffle / VectorLen>{}, number<VectorLen>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_block_desc;
#else
constexpr auto PaddingAmount = 0;
constexpr auto lds_block_desc = make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<NPerIterationShuffle + PaddingAmount>{}, number<1>{}),
number<VectorLen>{},
number<1>{});
return lds_block_desc;
#endif
}
// M is contiguous dimension
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t NLdsLayerRequired =
get_n_lds_banks() * BytesPerBank / MPerIterationShuffle / DataTypeSize;
constexpr auto NLdsLayer = max(1, NLdsLayerRequired);
constexpr index_t BaseStrideElems = MPerIterationShuffle * NLdsLayer;
static_assert((BaseStrideElems * DataTypeSize) % BytesPerBank == 0,
"LDS row stride must be 4B-aligned for bank-word padding logic");
#if defined(__gfx950__)
constexpr index_t ElemsPer4B = BytesPerBank / ck_tile::gcd(BytesPerBank, DataTypeSize);
constexpr auto ToWords = [](index_t elems) constexpr {
return (elems * DataTypeSize) / BytesPerBank;
};
constexpr index_t BaseWords = ToWords(BaseStrideElems);
constexpr index_t PadWords = ((BaseWords % 2) == 0) ? 1 : 0;
constexpr auto PaddingAmount = PadWords * ElemsPer4B;
#else
constexpr auto PaddingAmount = 0;
#endif
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NPerIterationShuffle / NLdsLayer>{},
number<MPerIterationShuffle / VectorLen * NLdsLayer>{},
number<VectorLen>{}),
make_tuple(number<MPerIterationShuffle * NLdsLayer + PaddingAmount>{},
number<VectorLen>{},
number<1>{}),
number<VectorLen>{},
number<1>{});
constexpr auto lds_block_desc_1 = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NPerIterationShuffle / NLdsLayer>{}),
make_unmerge_transform(make_tuple(
number<NLdsLayer>{}, number<MPerIterationShuffle / VectorLen>{})),
make_pass_through_transform(number<VectorLen>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
constexpr auto lds_block_desc = transform_tensor_descriptor(
lds_block_desc_1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(
number<NPerIterationShuffle / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform_v3_division_mod(make_tuple(
number<MPerIterationShuffle / VectorLen>{}, number<VectorLen>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_block_desc;
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
constexpr auto block_outer_dstr_encoding = [] {
if constexpr(BlockedXDLN_PerWarp == 1)
{
return tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
}
else
{
#if defined(__gfx950__)
constexpr auto is_950 = true;
#else
constexpr auto is_950 = false;
#endif
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
// BlockedLayout
// this branch is for original a16w4
if constexpr(is_950 || is_any_of<ADataTypeBuf, pk_int4_t, pk_fp4_t>::value ||
is_any_of<BDataTypeBuf, pk_int4_t, pk_fp4_t>::value)
{
if constexpr(EightWave)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
else
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
}
else
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, BlockedXDLN_PerWarp, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 1>>{};
}
}
}();
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
return block_dstr_encoding;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
return lds_block_desc.get_element_space_size() * sizeof(ODataType);
}
template <index_t iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
CK_TILE_DEVICE void
scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
{
// Check if scales are EmptyScale first (no scaling needed)
if constexpr(std::is_same_v<ScaleM, EmptyScale> && std::is_same_v<ScaleN, EmptyScale>)
{
// No scaling needed - this is a no-op
}
// Check if scales are scalar AccDataType
else if constexpr(std::is_same_v<ScaleM, AccDataType> &&
std::is_same_v<ScaleN, AccDataType>)
{
// Handle scalar scales
const AccDataType scale_m = scale_m_window;
const AccDataType scale_n = scale_n_window;
tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; },
lds_tile);
}
// Otherwise, assume they are tile windows that can be loaded
else
{
// Load tiles
const auto scale_m_tile = load_tile(scale_m_window);
const auto scale_n_tile = load_tile(scale_n_window);
// Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
tile_elementwise_inout(
element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
// Move scale windows
constexpr index_t num_access = SFC::get_num_of_access();
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(number<iAccess>{});
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
}
}
}
template <index_t iAccess, typename OAccTile, typename LdsTile>
CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile)
{
constexpr auto idx_y_start = SFC::get_index(number<iAccess>{});
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
}
template <typename LdsTile, typename InLdsWindow>
CK_TILE_DEVICE void cast_lds_tile(LdsTile& lds_tile, InLdsWindow& in_lds_window)
{
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
store_tile(in_lds_window, c_warptile_in_tensor_casted);
}
template <typename DramWindows, typename COutTensor>
CK_TILE_DEVICE void apply_d_tensors(DramWindows& d_dram_windows, COutTensor& c_out_tensor)
{
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
tile_elementwise_inout_unpack(elfunc_, c_ds_tiles);
}
template <typename OutDramWindow, typename COutTensor>
CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window,
const COutTensor& c_out_tensor)
{
if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
}
/**
* @brief Move both the output and D tensors windows for the next access.
*/
template <index_t iAccess, typename OutDramWindow, typename DDramWindows>
CK_TILE_DEVICE void move_windows(OutDramWindow& out_dram_window, DDramWindows& d_dram_windows)
{
constexpr index_t num_access = SFC::get_num_of_access();
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(number<iAccess>{});
// move the output dram window
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
// move windows for each of the D matrices (inputs for element-wise)
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx], {step.at(number<0>{}), step.at(number<1>{})});
});
}
}
// TODO: Check if there would be nicer ways to overload rather than with EmptyScale or nullptr_t
struct EmptyScale
{
};
template <typename, typename = void>
struct ScaleDataType
{
using DataType = float;
};
template <typename T>
struct ScaleDataType<T, std::void_t<typename T::DataType>>
{
using DataType = typename T::DataType;
};
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM = EmptyScale,
typename ScaleN = EmptyScale>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem,
const ScaleM& scale_m = {},
const ScaleN& scale_n = {})
{
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);
auto in_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0},
LdsTileDistr);
auto out_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<kBlockSize,
MPerIterationShuffle,
NPerIterationShuffle,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
Problem::kNumWaveGroups>;
constexpr auto dram_tile_distribution =
TileEncodingPattern::make_2d_static_tile_distribution();
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
constexpr bool has_scales =
!std::is_same_v<ScaleM, EmptyScale> && !std::is_same_v<ScaleN, EmptyScale>;
constexpr bool has_scalar_scales =
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
auto scale_m_window = [&]() {
if constexpr(has_scalar_scales)
{
return scale_m;
}
else if constexpr(has_scales)
{
return make_tile_window(scale_m, lds_tile.get_tile_distribution());
}
else
{
return EmptyScale{};
}
}();
auto scale_n_window = [&]() {
if constexpr(has_scalar_scales)
{
return scale_n;
}
else if constexpr(has_scales)
{
return make_tile_window(scale_n, lds_tile.get_tile_distribution());
}
else
{
return EmptyScale{};
}
}();
static_for<0, num_access, 1>{}([&](auto iAccess) {
block_sync_lds();
slice_acc_tile<iAccess>(o_acc_tile, lds_tile);
if constexpr(has_scales)
{
scale_tile<iAccess>(lds_tile, scale_m_window, scale_n_window);
}
cast_lds_tile(lds_tile, in_lds_window);
block_sync_lds();
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
apply_d_tensors(d_dram_windows, c_out_tensor);
store_to_dram(out_dram_window, c_out_tensor);
move_windows<iAccess>(out_dram_window, d_dram_windows);
});
}
};
} // namespace ck_tile