Fix build after merging branch ginolu/add_wgmfma_dispatcher

This commit is contained in:
mtgu0705
2025-09-09 04:37:42 -05:00
parent f119c30317
commit b0d71b8d19
9 changed files with 1037 additions and 339 deletions

View File

@@ -0,0 +1,760 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
/// @brief Compile-time problem descriptor for CShuffleEpilogue.
///
/// Bundles the data types, layouts, block/wave/XDL tiling parameters and
/// memory-operation mode that the epilogue needs. All members are either
/// type aliases or static constexpr constants; the struct carries no state.
template <typename ADataType_,
typename BDataType_,
typename DsDataType_,
typename AccDataType_,
typename ODataType_,
typename DsLayout_,
typename ELayout_,
typename CDElementwise_,
index_t kBlockSize_,
index_t kM_,
index_t kN_,
index_t MWave_,
index_t NWave_,
index_t MPerXdl_,
index_t NPerXdl_,
index_t KPerXdl_,
bool isCTransposed_,
memory_operation_enum MemoryOperation_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false,
index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp
struct CShuffleEpilogueProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
// Block size is now an explicit parameter (it used to be derived as
// MWave * NWave * warp_size in an earlier revision of this interface).
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;
static constexpr index_t NWave = NWave_;
static constexpr index_t MPerXdl = MPerXdl_;
static constexpr index_t NPerXdl = NPerXdl_;
static constexpr index_t KPerXdl = KPerXdl_;
// NOTE(review): declared index_t although the template parameter is bool;
// the implicit bool->index_t conversion is value-preserving (0/1).
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
// Number of fused D (bias/residual-like) tensors; each needs a matching layout.
static constexpr index_t NumDTensor = DsDataType::size();
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
};
/// @brief Epilogue that rearranges ("shuffles") per-warp MFMA/XDL accumulator
///        fragments into vectorized stores of the output tensor.
///
/// Four operator() overloads are provided, selected at compile time:
///   - TiledMMAPermuteN == true : in-register transpose of the thread-local
///     <kM2 x NRepeat> matrix, no LDS round trip (with and without M/N scales).
///   - TiledMMAPermuteN == false: round trip through LDS following a space
///     filling curve, then coalesced loads with a DRAM-friendly distribution
///     (with and without M/N scales; the scaled variant double-buffers the
///     register tile).
/// Optional D tensors are loaded per iteration and combined with the output
/// through Problem::CDElementwise.
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t MWave = Problem::MWave;
static constexpr index_t NWave = Problem::NWave;
static constexpr index_t MPerXdl = Problem::MPerXdl;
static constexpr index_t NPerXdl = Problem::NPerXdl;
static constexpr index_t KPerXdl = Problem::KPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
// Elements covered by all waves in one XDL iteration, per dimension.
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
// How many XDL iterations tile the block in each dimension.
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
/**
* @brief Get the vector store size for C tensor.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
{
// Caller may pin the vector size explicitly via the problem descriptor.
if constexpr(FixedVectorSize)
{
return VectorSizeC;
}
// 16 bytes = widest global store (e.g. dwordx4); cap the element count by it.
constexpr index_t max_vector_size = 16;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}
/**
* @brief Get the vector store size for Di tensor.
*
* Same policy as GetVectorSizeC(), but per-D-tensor: the contiguous
* dimension depends on that tensor's own layout and the cap depends on
* its element size.
*
* @return The vector store size for Di tensor.
*/
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
{
constexpr index_t max_vector_size = 16;
using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else
{
static_assert(false, "Unsupported DLayout!");
}
// NOTE(review): unreachable — every layout branch above returns (or fails
// the static_assert). Kept to silence "missing return" diagnostics.
return max_vector_size / sizeof(DiDataType);
}
/**
* @brief Shuffle tile configuration parameters
*
* @details These parameters control the number of XDL tiles processed per wave in each shuffle
* iteration:
* - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
* - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
*/
static constexpr auto shuffle_tile_tuple = [] {
// Elements each thread owns per single XDL tile.
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
if constexpr(elem_per_thread >= GetVectorSizeC())
{
// One XDL tile already supplies a full store vector per thread.
return std::make_tuple(1, 1);
}
else
{
// Fuse enough XDL tiles per shuffle so a thread can emit a full vector.
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
(kMPerBlock % num_xdl_shuffles == 0),
"kMPerBlock must be divisible by MPerXdl*MWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
}
else
{
static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
(kNPerBlock % num_xdl_shuffles == 0),
"kNPerBlock must be divisible by NPerXdl*NWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
}
}
}();
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
// N count is additionally raised to at least BlockedXDLN_PerWarp so the
// blocked-N warp layout (see MakeLdsDistributionEncode) stays expressible.
static constexpr index_t NumNXdlPerWavePerShuffle =
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
static_assert(NumNXdlPerWavePerShuffle % BlockedXDLN_PerWarp == 0);
// Per-shuffle-iteration tile extents; falls back to a single XDL iteration
// when the fused extents do not evenly divide the block tile.
static constexpr auto MNPerIterationShuffle = [] {
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
else
return std::make_tuple(m_val, n_val);
}();
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
// Warp-gemm traits matching the mainloop (B uses the dequantized type).
using WG = WarpGemmMfmaDispatcher<ADataType,
BTypeToUse,
AccDataType,
MPerXdl,
NPerXdl,
KPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
/// @brief LDS descriptor for one shuffle iteration tile; the contiguous
///        (stride-1) dimension follows the output ELayout.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
// N is contiguous dimension
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<NPerIterationShuffle>{}, number<1>{}));
}
// M is contiguous dimension
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<1>{}, number<MPerIterationShuffle>{}));
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}
/// @brief Tile-distribution encoding used when writing accumulator
///        fragments into LDS: the per-XDL/wave outer encoding embedded with
///        the warp-gemm C fragment encoding. BlockedXDLN_PerWarp > 1 keeps
///        that many consecutive XDL outputs per warp along N ("blocked"
///        rather than fully raked layout).
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
constexpr auto block_outer_dstr_encoding = [] {
if constexpr(BlockedXDLN_PerWarp == 1)
{
return tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
}
else
{
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
// BlockedLayout
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
}();
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
return block_dstr_encoding;
}
/// @brief LDS bytes needed by the epilogue: one shuffle-iteration tile of
///        ODataType (the tile is cast to ODataType before the LDS store).
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
}
/// @brief TiledMMAPermuteN path, no scales: transpose each thread-local
///        <kM2 x NRepeat> accumulator sub-matrix in registers and store
///        directly to DRAM — no LDS traffic (p_smem is unused here).
/// @note  D windows are created but not combined in this overload;
///        presumably the permute-N kernels fuse Ds elsewhere — TODO confirm.
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
int EnablePermuateN_ = TiledMMAPermuteN,
std::enable_if_t<EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem)
{
// Thread layout of one M-iteration: kM2 = 4 rows per thread is the MFMA
// accumulator fragment height assumed by the transpose below.
constexpr int kM0 = MWave;
constexpr int kM2 = 4;
constexpr int kM1 = MPerXdl / kM2;
constexpr int kN0 = NWave;
constexpr int kN1 = NPerXdl;
constexpr int kN2 = NRepeat;
using IntrThreadShuffleEncode =
tile_distribution_encoding<sequence<>,
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>;
static_assert(GetVectorSizeC() % kN2 == 0);
constexpr auto dram_tile_distribution =
make_static_tile_distribution(IntrThreadShuffleEncode{});
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
static_for<0, MRepeat, 1>{}([&](auto mIter) {
// Slice one M-iteration worth of accumulators (all NRepeat fragments).
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
// transpose <kM2 x NRepeat> thread matrix
c_out_tensor.get_thread_buffer()[n_idx + 0 * NRepeat] = type_convert<ODataType>(
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0]);
c_out_tensor.get_thread_buffer()[n_idx + 1 * NRepeat] = type_convert<ODataType>(
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1]);
c_out_tensor.get_thread_buffer()[n_idx + 2 * NRepeat] = type_convert<ODataType>(
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2]);
c_out_tensor.get_thread_buffer()[n_idx + 3 * NRepeat] = type_convert<ODataType>(
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3]);
});
// set => plain store; otherwise read-modify-write accumulate into DRAM.
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
// Advance all windows one M-iteration down.
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
});
});
}
/// @brief Default (non-permute) path, no scales: walk the block tile along a
///        space-filling curve; each step casts a register sub-tile to
///        ODataType, bounces it through LDS to re-distribute threads for
///        coalesced access, fuses Ds via CDElementwise, then stores/updates.
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
int EnablePermuateN_ = TiledMMAPermuteN,
std::enable_if_t<!EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem)
{
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
// LDS holds ODataType: the cast happens before the store (see below).
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);
// Write window (accumulator-side distribution) ...
auto in_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0},
LdsTileDistr);
// ... and read window (re-distributed for DRAM-friendly access).
auto out_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
using TileEncodingPattern =
TileDistributionEncodingPattern2D<kBlockSize,
MPerIterationShuffle,
NPerIterationShuffle,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
Problem::kNumWaveGroups>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static_for<0, num_access, 1>{}([&](auto iAccess) {
// Guard against the previous iteration's LDS reads before overwriting.
block_sync_lds();
constexpr auto idx_y_start = SFC::get_index(iAccess);
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
store_tile(in_lds_window, c_warptile_in_tensor_casted);
block_sync_lds();
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
// c_out_tensor appears twice: once as the output slot, once as input 0.
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie(
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
// Step output and D windows along the space-filling curve.
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx],
{step.at(number<0>{}), step.at(number<1>{})});
});
}
});
}
/// @brief TiledMMAPermuteN path with row (M) and column (N) scale windows
///        (e.g. per-token / per-channel dequant scales — TODO confirm with
///        callers). Scales are pre-loaded; N scales are read through a
///        permuted view matching the permuted N output order, then each
///        accumulator is multiplied by scale_n and scale_m before the same
///        in-register transpose/store as the unscaled overload.
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleMWindow,
typename ScaleNWindow,
int EnablePermuateN_ = TiledMMAPermuteN,
std::enable_if_t<EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem,
ScaleMWindow scale_m_window,
ScaleNWindow scale_n_window)
{
constexpr int kM0 = MWave;
constexpr int kM2 = 4;
constexpr int kM1 = MPerXdl / kM2;
static_assert(MPerXdl == 16, "TiledMMAPermuteN only supports MPerXdl = 16 now");
constexpr int kN0 = NWave;
constexpr int kN1 = NPerXdl;
constexpr int kN2 = NRepeat;
using IntrThreadShuffleEncode =
tile_distribution_encoding<sequence<>,
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>;
static_assert(GetVectorSizeC() % kN2 == 0);
constexpr auto dram_tile_distribution =
make_static_tile_distribution(IntrThreadShuffleEncode{});
constexpr int DynamicTileOffsetFlag = 0;
// Unmerge N into (.., NWave, NPerXdl, NRepeat) then re-merge in the
// permuted order (NRepeat innermost -> outermost) so scale_n elements
// line up with the permuted-N accumulator layout.
auto permute_scale_n_view_1 = transform_tensor_view(
scale_n_window.get_bottom_tensor_view(),
make_tuple(make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
make_unmerge_transform(make_tuple(number<DynamicTileOffsetFlag>{},
number<NWave>{},
number<NPerXdl>{},
number<NRepeat>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2, 3, 4>{}));
auto permute_scale_n_view = transform_tensor_view(
permute_scale_n_view_1,
make_tuple(
make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
make_merge_transform_v3_division_mod(make_tuple(number<DynamicTileOffsetFlag>{},
number<NRepeat>{},
number<NWave>{},
number<NPerXdl>{}))),
make_tuple(sequence<0>{}, sequence<1, 4, 2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// Re-window the scales with the accumulator's distribution so slicing
// below stays index-compatible with o_acc_tile.
auto scale_m_window_with_dist = make_tile_window(
scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution());
auto scale_n_window_with_dist = make_tile_window(permute_scale_n_view,
scale_n_window.get_window_lengths(),
scale_n_window.get_window_origin(),
o_acc_tile.get_tile_distribution());
auto scale_m_buffer = load_tile(scale_m_window_with_dist);
auto scale_n_buffer = load_tile(scale_n_window_with_dist);
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
using ShuffleAcc =
decltype(make_static_distributed_tensor<AccDataType>(dram_tile_distribution));
// One register tile per M-iteration: N-scaling is applied to all of them
// first, M-scaling + transpose + store in a second pass below.
ShuffleAcc shuffle_acc[MRepeat];
auto c_out_tensor_fp32 =
make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
constexpr int NumAccPerEpiTile = NRepeat * c_warp_y_lengths.product();
static_for<0, MRepeat, 1>{}([&](auto mIter) {
shuffle_acc[mIter].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
auto epi_scale_n = scale_n_buffer.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
static_for<0, NumAccPerEpiTile, 1>{}(
[&](auto i) { shuffle_acc[mIter].get_thread_buffer()[i] *= epi_scale_n[i]; });
});
static_for<0, MRepeat, 1>{}([&](auto mIter) {
auto epi_scale_m = scale_m_buffer.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
// transpose <kM2 x NRepeat> thread matrix
c_out_tensor_fp32.get_thread_buffer()[n_idx + 0 * NRepeat] =
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0] *
epi_scale_m[n_idx * c_warp_y_lengths.product() + 0];
c_out_tensor_fp32.get_thread_buffer()[n_idx + 1 * NRepeat] =
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1] *
epi_scale_m[n_idx * c_warp_y_lengths.product() + 1];
c_out_tensor_fp32.get_thread_buffer()[n_idx + 2 * NRepeat] =
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2] *
epi_scale_m[n_idx * c_warp_y_lengths.product() + 2];
c_out_tensor_fp32.get_thread_buffer()[n_idx + 3 * NRepeat] =
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3] *
epi_scale_m[n_idx * c_warp_y_lengths.product() + 3];
});
c_out_tensor = cast_tile<ODataType>(c_out_tensor_fp32);
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
});
});
}
/// @brief Default (non-permute) path with M/N scales: same LDS round trip
///        as the unscaled overload, but the scaled register tile is
///        double-buffered (lds_tile[2]) so scaling of the next SFC step
///        overlaps the LDS store of the current one.
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleMWindow,
typename ScaleNWindow,
int EnablePermuateN_ = TiledMMAPermuteN,
std::enable_if_t<!EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem,
ScaleMWindow scale_m_window,
ScaleNWindow scale_n_window)
{
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
using LDSTileTensor = decltype(make_static_distributed_tensor<AccDataType>(LdsTileDistr));
// Two register staging tiles: [read_stage] feeds LDS while [write_stage]
// is being filled/scaled for the next iteration.
LDSTileTensor lds_tile[2];
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);
auto in_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0},
LdsTileDistr);
auto out_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
using TileEncodingPattern =
TileDistributionEncodingPattern2D<kBlockSize,
MPerIterationShuffle,
NPerIterationShuffle,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
Problem::kNumWaveGroups>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// Scales are read with the accumulator's own distribution (no permuted
// view needed here, unlike the permute-N overload).
auto scale_m_window_with_dist = make_tile_window(
scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution());
auto scale_n_window_with_dist = make_tile_window(
scale_n_window, scale_n_window.get_window_origin(), o_acc_tile.get_tile_distribution());
auto scale_m_buffer = load_tile(scale_m_window_with_dist);
auto scale_n_buffer = load_tile(scale_n_window_with_dist);
constexpr int NumAccPerEpiTile =
NumMXdlPerWavePerShuffle * NumNXdlPerWavePerShuffle * c_warp_y_lengths.product();
// Slice the (epi_m_idx, epi_n_idx)-th shuffle tile out of an
// accumulator-distributed tensor.
auto epi_tile_idx_slice =
[&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) {
return acc_tile_like_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
epi_n_idx * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
};
// Prologue: stage and scale tile (0,0) into buffer 0.
lds_tile[0].get_thread_buffer() = epi_tile_idx_slice(o_acc_tile, number<0>{}, number<0>{});
auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, number<0>{}, number<0>{});
auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, number<0>{}, number<0>{});
static_for<0, NumAccPerEpiTile, 1>{}(
[&](auto i) { lds_tile[0].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i]; });
static_for<0, num_access, 1>{}([&](auto iAccess) {
constexpr int read_stage = iAccess % 2;
constexpr int write_stage = read_stage ^ 1;
block_sync_lds();
// mIter/nIter index the NEXT tile (iAccess + 1) being prefetched into
// write_stage. NOTE(review): get_index is evaluated at iAccess + 1 ==
// num_access on the last iteration even though the result is only used
// under the iAccess < num_access - 1 guard — confirm SFC tolerates this.
constexpr auto idx_y_start = SFC::get_index(number<iAccess.value + 1>{});
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile[read_stage]);
store_tile(in_lds_window, c_warptile_in_tensor_casted);
if constexpr(iAccess < num_access - 1)
{
lds_tile[write_stage].get_thread_buffer() =
epi_tile_idx_slice(o_acc_tile, mIter, nIter);
epi_scale_m = epi_tile_idx_slice(scale_m_buffer, mIter, nIter);
epi_scale_n = epi_tile_idx_slice(scale_n_buffer, mIter, nIter);
static_for<0, NumAccPerEpiTile, 1>{}([&](auto i) {
lds_tile[write_stage].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i];
});
}
block_sync_lds();
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie(
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx],
{step.at(number<0>{}), step.at(number<1>{})});
});
}
});
}
};
} // namespace ck_tile

View File

@@ -20,6 +20,7 @@ template <typename ADataType_,
typename DsLayout_,
typename ELayout_,
typename CDElementwise_,
index_t kBlockSize_,
index_t kM_,
index_t kN_,
index_t MWave_,
@@ -44,7 +45,7 @@ struct CShuffleEpilogueProblem
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;
@@ -228,8 +229,8 @@ struct CShuffleEpilogue
using CWarpTensor = typename WG::CWarpTensor;
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
@@ -376,9 +377,7 @@ struct CShuffleEpilogue
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem,
const ScaleM& scale_m = {},
const ScaleN& scale_n = {})
void* p_smem)
{
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
@@ -399,9 +398,10 @@ struct CShuffleEpilogue
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
// using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
// sequence<0, 1>,
// sequence<MPerIterationShuffle,
// NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
@@ -414,8 +414,7 @@ struct CShuffleEpilogue
GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
Problem::kNumWaveGroups>;
constexpr auto dram_tile_distribution =
TileEncodingPattern::make_2d_static_tile_distribution();
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
auto d_dram_windows = generate_tuple(
[&](auto idx) {
@@ -423,39 +422,27 @@ struct CShuffleEpilogue
},
number<NumDTensor>{});
constexpr bool has_scales =
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
auto scale_m_window = [&]() {
if constexpr(has_scales)
{
return make_tile_window(scale_m, lds_tile.get_tile_distribution());
}
else
{
return EmptyScale{};
}
}();
auto scale_n_window = [&]() {
if constexpr(has_scales)
{
return make_tile_window(scale_n, lds_tile.get_tile_distribution());
}
else
{
return EmptyScale{};
}
}();
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static_for<0, num_access, 1>{}([&](auto iAccess) {
block_sync_lds();
slice_acc_tile<iAccess>(o_acc_tile, lds_tile);
constexpr auto idx_y_start = SFC::get_index(iAccess);
if constexpr(has_scales)
{
scale_tile<iAccess>(lds_tile, scale_m_window, scale_n_window);
}
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
cast_lds_tile(lds_tile, in_lds_window);
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
store_tile(in_lds_window, c_warptile_in_tensor_casted);
block_sync_lds();
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
@@ -465,8 +452,8 @@ struct CShuffleEpilogue
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
generate_tie(
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
@@ -495,16 +482,16 @@ struct CShuffleEpilogue
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM,
typename ScaleN,
typename ScaleMWindow,
typename ScaleNWindow,
int EnablePermuateN_ = TiledMMAPermuteN,
std::enable_if_t<EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem,
ScaleM scale_m,
ScaleN scale_n)
ScaleMWindow scale_m_window,
ScaleNWindow scale_n_window)
{
constexpr int kM0 = MWave;
constexpr int kM2 = 4;
@@ -522,9 +509,43 @@ struct CShuffleEpilogue
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>;
static_assert(GetVectorSizeC() % kN2 == 0);
constexpr auto dram_tile_distribution =
make_static_tile_distribution(IntrThreadShuffleEncode{});
constexpr int DynamicTileOffsetFlag = 0;
auto permute_scale_n_view_1 = transform_tensor_view(
scale_n_window.get_bottom_tensor_view(),
make_tuple(make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
make_unmerge_transform(make_tuple(number<DynamicTileOffsetFlag>{},
number<NWave>{},
number<NPerXdl>{},
number<NRepeat>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2, 3, 4>{}));
auto permute_scale_n_view = transform_tensor_view(
permute_scale_n_view_1,
make_tuple(
make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
make_merge_transform_v3_division_mod(make_tuple(number<DynamicTileOffsetFlag>{},
number<NRepeat>{},
number<NWave>{},
number<NPerXdl>{}))),
make_tuple(sequence<0>{}, sequence<1, 4, 2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto scale_m_window_with_dist = make_tile_window(
scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution());
auto scale_n_window_with_dist = make_tile_window(permute_scale_n_view,
scale_n_window.get_window_lengths(),
scale_n_window.get_window_origin(),
o_acc_tile.get_tile_distribution());
auto scale_m_buffer = load_tile(scale_m_window_with_dist);
auto scale_n_buffer = load_tile(scale_n_window_with_dist);
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
@@ -542,56 +563,39 @@ struct CShuffleEpilogue
make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
const index_t iMWarp = get_warp_id() / NWave;
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
const index_t iMLane = get_lane_id() / NPerXdl;
const index_t iNLane = get_lane_id() % NPerXdl;
float vec_scale_A[kM2 * MRepeat];
float vec_scale_B[NRepeat];
_Pragma("unroll") for(int i = 0; i < NRepeat; ++i)
{
vec_scale_B[i] = scale_n[i + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl];
}
_Pragma("unroll") for(int i = 0; i < MRepeat; ++i)
{
vec_scale_A[i * kM2 + 0] =
scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
vec_scale_A[i * kM2 + 1] =
scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
vec_scale_A[i * kM2 + 2] =
scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
vec_scale_A[i * kM2 + 3] =
scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
}
constexpr int NumAccPerEpiTile = NRepeat * c_warp_y_lengths.product();
static_for<0, MRepeat, 1>{}([&](auto mIter) {
shuffle_acc[mIter].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 0] *= vec_scale_B[n_idx];
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 1] *= vec_scale_B[n_idx];
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 2] *= vec_scale_B[n_idx];
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 3] *= vec_scale_B[n_idx];
});
auto epi_scale_n = scale_n_buffer.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
static_for<0, NumAccPerEpiTile, 1>{}(
[&](auto i) { shuffle_acc[mIter].get_thread_buffer()[i] *= epi_scale_n[i]; });
});
static_for<0, MRepeat, 1>{}([&](auto mIter) {
auto epi_scale_m = scale_m_buffer.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
// transpose <kM2 x NRepeat> thread matrix
c_out_tensor_fp32.get_thread_buffer()[n_idx + 0 * NRepeat] =
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0] *
vec_scale_A[mIter * kM2 + 0];
epi_scale_m[n_idx * c_warp_y_lengths.product() + 0];
c_out_tensor_fp32.get_thread_buffer()[n_idx + 1 * NRepeat] =
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1] *
vec_scale_A[mIter * kM2 + 1];
epi_scale_m[n_idx * c_warp_y_lengths.product() + 1];
c_out_tensor_fp32.get_thread_buffer()[n_idx + 2 * NRepeat] =
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2] *
vec_scale_A[mIter * kM2 + 2];
epi_scale_m[n_idx * c_warp_y_lengths.product() + 2];
c_out_tensor_fp32.get_thread_buffer()[n_idx + 3 * NRepeat] =
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3] *
vec_scale_A[mIter * kM2 + 3];
epi_scale_m[n_idx * c_warp_y_lengths.product() + 3];
});
c_out_tensor = cast_tile<ODataType>(c_out_tensor_fp32);
@@ -615,16 +619,16 @@ struct CShuffleEpilogue
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM,
typename ScaleN,
typename ScaleMWindow,
typename ScaleNWindow,
int EnablePermuateN_ = TiledMMAPermuteN,
std::enable_if_t<!EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem,
ScaleM scale_m,
ScaleN scale_n)
ScaleMWindow scale_m_window,
ScaleNWindow scale_n_window)
{
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
@@ -646,21 +650,18 @@ struct CShuffleEpilogue
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
using TileEncodingPattern =
TileDistributionEncodingPattern2D<kBlockSize,
MPerIterationShuffle,
NPerIterationShuffle,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
Problem::kNumWaveGroups>;
tile_distribution_encoding_pattern_2d<kBlockSize,
MPerIterationShuffle,
NPerIterationShuffle,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
Problem::kNumWaveGroups>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
auto d_dram_windows = generate_tuple(
@@ -673,63 +674,32 @@ struct CShuffleEpilogue
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr int kM2 = 4; // Val
constexpr int kM1 = (64 / NPerXdl); // Thr
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
auto scale_m_window_with_dist = make_tile_window(
scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution());
auto scale_n_window_with_dist = make_tile_window(
scale_n_window, scale_n_window.get_window_origin(), o_acc_tile.get_tile_distribution());
const index_t iMWarp = get_warp_id() / NWave;
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
const index_t iMLane = get_lane_id() / NPerXdl;
const index_t iNLane = get_lane_id() % NPerXdl;
auto scale_m_buffer = load_tile(scale_m_window_with_dist);
auto scale_n_buffer = load_tile(scale_n_window_with_dist);
float vec_scale_A[kM0 * kM2 * MRepeat];
float vec_scale_B[NRepeat];
constexpr int NumAccPerEpiTile =
NumMXdlPerWavePerShuffle * NumNXdlPerWavePerShuffle * c_warp_y_lengths.product();
auto epi_tile_idx_slice =
[&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) {
return acc_tile_like_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
epi_n_idx * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
};
_Pragma("unroll") for(int i = 0; i < NRepeat; ++i)
{
vec_scale_B[i] = scale_n[i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
}
_Pragma("unroll") for(int i = 0; i < MRepeat; ++i)
{
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
{
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 0] =
scale_m[0 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
i * MPerXdl * MWave];
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 1] =
scale_m[1 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
i * MPerXdl * MWave];
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 2] =
scale_m[2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
i * MPerXdl * MWave];
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 3] =
scale_m[3 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
i * MPerXdl * MWave];
}
}
lds_tile[0].get_thread_buffer() = epi_tile_idx_slice(o_acc_tile, number<0>{}, number<0>{});
lds_tile[0].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
{
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 0] * vec_scale_B[n_xdl];
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 1] * vec_scale_B[n_xdl];
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 2] * vec_scale_B[n_xdl];
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 3] * vec_scale_B[n_xdl];
}
});
});
auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, number<0>{}, number<0>{});
auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, number<0>{}, number<0>{});
static_for<0, NumAccPerEpiTile, 1>{}(
[&](auto i) { lds_tile[0].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i]; });
static_for<0, num_access, 1>{}([&](auto iAccess) {
constexpr int read_stage = iAccess % 2;
@@ -747,40 +717,14 @@ struct CShuffleEpilogue
if constexpr(iAccess < num_access - 1)
{
lds_tile[write_stage].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter * NumMXdlPerWavePerShuffle,
nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
{
lds_tile[write_stage]
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + 0] *
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
lds_tile[write_stage]
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + 1] *
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
lds_tile[write_stage]
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + 2] *
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
lds_tile[write_stage]
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + 3] *
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
}
});
lds_tile[write_stage].get_thread_buffer() =
epi_tile_idx_slice(o_acc_tile, mIter, nIter);
epi_scale_m = epi_tile_idx_slice(scale_m_buffer, mIter, nIter);
epi_scale_n = epi_tile_idx_slice(scale_n_buffer, mIter, nIter);
static_for<0, NumAccPerEpiTile, 1>{}([&](auto i) {
lds_tile[write_stage].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i];
});
}
@@ -793,8 +737,8 @@ struct CShuffleEpilogue
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
generate_tie(
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);