mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
fix after merge ginolu/add_wgmfma_dispatcher
This commit is contained in:
@@ -107,36 +107,36 @@ CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
|
||||
template <typename X>
|
||||
CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, const fp16x2_t& x)
|
||||
{
|
||||
union U32FP162_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
fp16x2_t* fp162_a;
|
||||
};
|
||||
// template <>
|
||||
// CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, const fp16x2_t& x)
|
||||
// {
|
||||
// union U32FP162_ADDR
|
||||
// {
|
||||
// uint32_t* u32_a;
|
||||
// fp16x2_t* fp162_a;
|
||||
// };
|
||||
|
||||
union U32FP162
|
||||
{
|
||||
uint32_t u32;
|
||||
fp16x2_t fp162;
|
||||
};
|
||||
// union U32FP162
|
||||
// {
|
||||
// uint32_t u32;
|
||||
// fp16x2_t fp162;
|
||||
// };
|
||||
|
||||
U32FP162_ADDR dword_addr;
|
||||
U32FP162 cur_v;
|
||||
U32FP162 new_;
|
||||
uint32_t old_v, new_v;
|
||||
dword_addr.fp162_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
// U32FP162_ADDR dword_addr;
|
||||
// U32FP162 cur_v;
|
||||
// U32FP162 new_;
|
||||
// uint32_t old_v, new_v;
|
||||
// dword_addr.fp162_a = p_dst;
|
||||
// cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.fp162 = add_fp16x2_t(cur_v.fp162, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
}
|
||||
// do
|
||||
// {
|
||||
// old_v = cur_v.u32;
|
||||
// new_.fp162 = add_fp16x2_t(cur_v.fp162, x);
|
||||
// new_v = new_.u32;
|
||||
// cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
// } while(cur_v.u32 != old_v);
|
||||
// }
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
|
||||
|
||||
@@ -31,10 +31,9 @@ struct e8m0_bexp_t
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t() : data{type{0b11111111}} {}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(type init) : data{init} {}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale)
|
||||
: e8m0_bexp_t(static_cast<type>(numeric_utils<float>::get_exponent(scale)))
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale) : data(0)
|
||||
{
|
||||
data = numeric_utils<float>::get_exponent(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
|
||||
|
||||
@@ -282,7 +282,7 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
const std::size_t ScaleBlockSize = K / a_m_k_scale.get_length(1);
|
||||
const std::size_t ScaleBlockSize = K / scale_a.get_length(1);
|
||||
|
||||
HostTensor<AccDataType> a_m_k_scaled({M, K}, {K, 1});
|
||||
HostTensor<AccDataType> b_k_n_scaled({K, N}, {1, N});
|
||||
@@ -291,19 +291,19 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, f4x2_pk_t>)
|
||||
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
if(k % 2 == 1)
|
||||
continue; // skip odd k
|
||||
|
||||
auto a_f4x2 = a_m_k(m, k);
|
||||
auto a_scale = a_m_k_scale(m, k / ScaleBlockSize);
|
||||
auto a_scale = scale_a(m, k / ScaleBlockSize);
|
||||
// auto f4_lo = ck_tile::type_convert<AccDataType>(f4x2)[0];
|
||||
// auto f4_hi = ck_tile::type_convert<AccDataType>(f4x2)[1];
|
||||
aut a_f4_lo =
|
||||
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(Number<0>{}));
|
||||
auto a_f4_lo =
|
||||
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<0>{}));
|
||||
auto a_f4_hi =
|
||||
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(Number<1>{}));
|
||||
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<1>{}));
|
||||
|
||||
a_m_k_scaled(m, k) = a_f4_lo * a_scale;
|
||||
a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
|
||||
@@ -315,19 +315,19 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDatatype, f4x2_pk_t>)
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
if(k % 2 == 1)
|
||||
continue; // skip odd k
|
||||
|
||||
auto b_f4x2 = b_k_n(k, n);
|
||||
auto b_scale = b_k_n_scale(k / ScaleBlockSize, n);
|
||||
auto b_scale = scale_b(k / ScaleBlockSize, n);
|
||||
// auto f4_lo = ck_tile::type_convert<AccDataType>(f4x2)[0];
|
||||
// auto f4_hi = ck_tile::type_convert<AccDataType>(f4x2)[1];
|
||||
auto b_f4_lo =
|
||||
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(Number<0>{}));
|
||||
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<0>{}));
|
||||
auto b_f4_hi =
|
||||
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(Number<1>{}));
|
||||
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<1>{}));
|
||||
|
||||
b_k_n_scaled(k, n) = b_f4_lo * b_scale;
|
||||
b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
|
||||
@@ -336,7 +336,7 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
{
|
||||
b_k_n_scaled(k, n) =
|
||||
ck_tile::type_convert<AccDataType>((b_k_n(k, n))) *
|
||||
ck_tile::type_convert<AccDataType>(b_k_n_scale(k / ScaleBlockSize, n));
|
||||
ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
760
include/ck_tile/ops/epilogue/chuffle_epilogue_feat.hpp
Normal file
760
include/ck_tile/ops/epilogue/chuffle_epilogue_feat.hpp
Normal file
@@ -0,0 +1,760 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename DsDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename CDElementwise_,
|
||||
index_t kBlockSize_,
|
||||
index_t kM_,
|
||||
index_t kN_,
|
||||
index_t MWave_,
|
||||
index_t NWave_,
|
||||
index_t MPerXdl_,
|
||||
index_t NPerXdl_,
|
||||
index_t KPerXdl_,
|
||||
bool isCTransposed_,
|
||||
memory_operation_enum MemoryOperation_,
|
||||
index_t kNumWaveGroups_ = 1,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1,
|
||||
bool TiledMMAPermuteN_ = false,
|
||||
index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using DsDataType = remove_cvref_t<DsDataType_>;
|
||||
using DsLayout = remove_cvref_t<DsLayout_>;
|
||||
using ELayout = remove_cvref_t<ELayout_>;
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
static constexpr index_t MWave = MWave_;
|
||||
static constexpr index_t NWave = NWave_;
|
||||
static constexpr index_t MPerXdl = MPerXdl_;
|
||||
static constexpr index_t NPerXdl = NPerXdl_;
|
||||
static constexpr index_t KPerXdl = KPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
|
||||
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct CShuffleEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
|
||||
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using ELayout = remove_cvref_t<typename Problem::ELayout>;
|
||||
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
|
||||
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr index_t MWave = Problem::MWave;
|
||||
static constexpr index_t NWave = Problem::NWave;
|
||||
static constexpr index_t MPerXdl = Problem::MPerXdl;
|
||||
static constexpr index_t NPerXdl = Problem::NPerXdl;
|
||||
static constexpr index_t KPerXdl = Problem::KPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t NumDTensor = Problem::NumDTensor;
|
||||
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
|
||||
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
/**
|
||||
* @brief Get the vector store size for C tensor.
|
||||
*
|
||||
* @note The vector store size for output C tensor would depend on multiple factors
|
||||
* like its data layout and warp gemm C transposition. In general it would
|
||||
* be the number of consecutive elements in contiguous C dimension hold by
|
||||
* single thread.
|
||||
*
|
||||
* @return The vector store size for C tensor.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
|
||||
{
|
||||
if constexpr(FixedVectorSize)
|
||||
{
|
||||
return VectorSizeC;
|
||||
}
|
||||
constexpr index_t max_vector_size = 16;
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(NPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(ODataType)));
|
||||
}
|
||||
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(MPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(ODataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported ELayout!");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the vector store size for Di tensor.
|
||||
*
|
||||
* @return The vector store size for Di tensor.
|
||||
*/
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
|
||||
{
|
||||
constexpr index_t max_vector_size = 16;
|
||||
using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(NPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(DiDataType)));
|
||||
}
|
||||
else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(MPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(DiDataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported DLayout!");
|
||||
}
|
||||
return max_vector_size / sizeof(DiDataType);
|
||||
}
|
||||
/**
|
||||
* @brief Shuffle tile configuration parameters
|
||||
*
|
||||
* @details These parameters control the number of XDL tiles processed per wave in each shuffle
|
||||
* iteration:
|
||||
* - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
|
||||
* - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
|
||||
*/
|
||||
static constexpr auto shuffle_tile_tuple = [] {
|
||||
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
|
||||
if constexpr(elem_per_thread >= GetVectorSizeC())
|
||||
{
|
||||
return std::make_tuple(1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
|
||||
(kMPerBlock % num_xdl_shuffles == 0),
|
||||
"kMPerBlock must be divisible by MPerXdl*MWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogue");
|
||||
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
|
||||
(kNPerBlock % num_xdl_shuffles == 0),
|
||||
"kNPerBlock must be divisible by NPerXdl*NWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogue");
|
||||
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
|
||||
}
|
||||
}
|
||||
}();
|
||||
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
|
||||
static constexpr index_t NumNXdlPerWavePerShuffle =
|
||||
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
|
||||
|
||||
static_assert(NumNXdlPerWavePerShuffle % BlockedXDLN_PerWarp == 0);
|
||||
|
||||
static constexpr auto MNPerIterationShuffle = [] {
|
||||
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
|
||||
constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
|
||||
if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
|
||||
return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
|
||||
else
|
||||
return std::make_tuple(m_val, n_val);
|
||||
}();
|
||||
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
|
||||
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
|
||||
|
||||
using WG = WarpGemmMfmaDispatcher<ADataType,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
KPerXdl,
|
||||
isCTransposed>;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
|
||||
{
|
||||
// N is contiguous dimension
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<NPerIterationShuffle>{}, number<1>{}));
|
||||
}
|
||||
// M is contiguous dimension
|
||||
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<1>{}, number<MPerIterationShuffle>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported ELayout!");
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
|
||||
{
|
||||
constexpr auto block_outer_dstr_encoding = [] {
|
||||
if constexpr(BlockedXDLN_PerWarp == 1)
|
||||
{
|
||||
return tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<NumNXdlPerWavePerShuffle, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
|
||||
// BlockedLayout
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{};
|
||||
}
|
||||
}();
|
||||
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
|
||||
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
|
||||
|
||||
return block_dstr_encoding;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
|
||||
}
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem)
|
||||
{
|
||||
constexpr int kM0 = MWave;
|
||||
constexpr int kM2 = 4;
|
||||
constexpr int kM1 = MPerXdl / kM2;
|
||||
|
||||
constexpr int kN0 = NWave;
|
||||
constexpr int kN1 = NPerXdl;
|
||||
constexpr int kN2 = NRepeat;
|
||||
|
||||
using IntrThreadShuffleEncode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>;
|
||||
static_assert(GetVectorSizeC() % kN2 == 0);
|
||||
|
||||
constexpr auto dram_tile_distribution =
|
||||
make_static_tile_distribution(IntrThreadShuffleEncode{});
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
|
||||
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
// transpose <kM2 x NRepeat> thread matrix
|
||||
c_out_tensor.get_thread_buffer()[n_idx + 0 * NRepeat] = type_convert<ODataType>(
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0]);
|
||||
c_out_tensor.get_thread_buffer()[n_idx + 1 * NRepeat] = type_convert<ODataType>(
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1]);
|
||||
c_out_tensor.get_thread_buffer()[n_idx + 2 * NRepeat] = type_convert<ODataType>(
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2]);
|
||||
c_out_tensor.get_thread_buffer()[n_idx + 3 * NRepeat] = type_convert<ODataType>(
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3]);
|
||||
});
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<!EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem)
|
||||
{
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
|
||||
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
|
||||
|
||||
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<ODataType*>(p_smem), lds_block_desc);
|
||||
|
||||
auto in_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0},
|
||||
LdsTileDistr);
|
||||
|
||||
auto out_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
|
||||
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPattern2D<kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
block_sync_lds();
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
|
||||
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
block_sync_lds();
|
||||
|
||||
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
const auto ds_tensor = generate_tuple(
|
||||
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie(
|
||||
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx],
|
||||
{step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
typename ScaleMWindow,
|
||||
typename ScaleNWindow,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem,
|
||||
ScaleMWindow scale_m_window,
|
||||
ScaleNWindow scale_n_window)
|
||||
{
|
||||
constexpr int kM0 = MWave;
|
||||
constexpr int kM2 = 4;
|
||||
constexpr int kM1 = MPerXdl / kM2;
|
||||
static_assert(MPerXdl == 16, "TiledMMAPermuteN only supports MPerXdl = 16 now");
|
||||
|
||||
constexpr int kN0 = NWave;
|
||||
constexpr int kN1 = NPerXdl;
|
||||
constexpr int kN2 = NRepeat;
|
||||
|
||||
using IntrThreadShuffleEncode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>;
|
||||
static_assert(GetVectorSizeC() % kN2 == 0);
|
||||
|
||||
constexpr auto dram_tile_distribution =
|
||||
make_static_tile_distribution(IntrThreadShuffleEncode{});
|
||||
|
||||
constexpr int DynamicTileOffsetFlag = 0;
|
||||
|
||||
auto permute_scale_n_view_1 = transform_tensor_view(
|
||||
scale_n_window.get_bottom_tensor_view(),
|
||||
make_tuple(make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
|
||||
make_unmerge_transform(make_tuple(number<DynamicTileOffsetFlag>{},
|
||||
number<NWave>{},
|
||||
number<NPerXdl>{},
|
||||
number<NRepeat>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2, 3, 4>{}));
|
||||
auto permute_scale_n_view = transform_tensor_view(
|
||||
permute_scale_n_view_1,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
|
||||
make_merge_transform_v3_division_mod(make_tuple(number<DynamicTileOffsetFlag>{},
|
||||
number<NRepeat>{},
|
||||
number<NWave>{},
|
||||
number<NPerXdl>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 4, 2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
auto scale_m_window_with_dist = make_tile_window(
|
||||
scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution());
|
||||
auto scale_n_window_with_dist = make_tile_window(permute_scale_n_view,
|
||||
scale_n_window.get_window_lengths(),
|
||||
scale_n_window.get_window_origin(),
|
||||
o_acc_tile.get_tile_distribution());
|
||||
|
||||
auto scale_m_buffer = load_tile(scale_m_window_with_dist);
|
||||
auto scale_n_buffer = load_tile(scale_n_window_with_dist);
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
using ShuffleAcc =
|
||||
decltype(make_static_distributed_tensor<AccDataType>(dram_tile_distribution));
|
||||
ShuffleAcc shuffle_acc[MRepeat];
|
||||
auto c_out_tensor_fp32 =
|
||||
make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
|
||||
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
|
||||
|
||||
constexpr int NumAccPerEpiTile = NRepeat * c_warp_y_lengths.product();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
shuffle_acc[mIter].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
auto epi_scale_n = scale_n_buffer.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
static_for<0, NumAccPerEpiTile, 1>{}(
|
||||
[&](auto i) { shuffle_acc[mIter].get_thread_buffer()[i] *= epi_scale_n[i]; });
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
auto epi_scale_m = scale_m_buffer.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
// transpose <kM2 x NRepeat> thread matrix
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 0 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0] *
|
||||
epi_scale_m[n_idx * c_warp_y_lengths.product() + 0];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 1 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1] *
|
||||
epi_scale_m[n_idx * c_warp_y_lengths.product() + 1];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 2 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2] *
|
||||
epi_scale_m[n_idx * c_warp_y_lengths.product() + 2];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 3 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3] *
|
||||
epi_scale_m[n_idx * c_warp_y_lengths.product() + 3];
|
||||
});
|
||||
|
||||
c_out_tensor = cast_tile<ODataType>(c_out_tensor_fp32);
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
typename ScaleMWindow,
|
||||
typename ScaleNWindow,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<!EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem,
|
||||
ScaleMWindow scale_m_window,
|
||||
ScaleNWindow scale_n_window)
|
||||
{
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
|
||||
using LDSTileTensor = decltype(make_static_distributed_tensor<AccDataType>(LdsTileDistr));
|
||||
LDSTileTensor lds_tile[2];
|
||||
|
||||
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<ODataType*>(p_smem), lds_block_desc);
|
||||
|
||||
auto in_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0},
|
||||
LdsTileDistr);
|
||||
|
||||
auto out_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
|
||||
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPattern2D<kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
auto scale_m_window_with_dist = make_tile_window(
|
||||
scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution());
|
||||
auto scale_n_window_with_dist = make_tile_window(
|
||||
scale_n_window, scale_n_window.get_window_origin(), o_acc_tile.get_tile_distribution());
|
||||
|
||||
auto scale_m_buffer = load_tile(scale_m_window_with_dist);
|
||||
auto scale_n_buffer = load_tile(scale_n_window_with_dist);
|
||||
|
||||
constexpr int NumAccPerEpiTile =
|
||||
NumMXdlPerWavePerShuffle * NumNXdlPerWavePerShuffle * c_warp_y_lengths.product();
|
||||
auto epi_tile_idx_slice =
|
||||
[&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) {
|
||||
return acc_tile_like_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
|
||||
epi_n_idx * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
};
|
||||
|
||||
lds_tile[0].get_thread_buffer() = epi_tile_idx_slice(o_acc_tile, number<0>{}, number<0>{});
|
||||
|
||||
auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, number<0>{}, number<0>{});
|
||||
auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, number<0>{}, number<0>{});
|
||||
static_for<0, NumAccPerEpiTile, 1>{}(
|
||||
[&](auto i) { lds_tile[0].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i]; });
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
constexpr int read_stage = iAccess % 2;
|
||||
constexpr int write_stage = read_stage ^ 1;
|
||||
|
||||
block_sync_lds();
|
||||
constexpr auto idx_y_start = SFC::get_index(number<iAccess.value + 1>{});
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile[read_stage]);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
|
||||
if constexpr(iAccess < num_access - 1)
|
||||
{
|
||||
lds_tile[write_stage].get_thread_buffer() =
|
||||
epi_tile_idx_slice(o_acc_tile, mIter, nIter);
|
||||
|
||||
epi_scale_m = epi_tile_idx_slice(scale_m_buffer, mIter, nIter);
|
||||
epi_scale_n = epi_tile_idx_slice(scale_n_buffer, mIter, nIter);
|
||||
|
||||
static_for<0, NumAccPerEpiTile, 1>{}([&](auto i) {
|
||||
lds_tile[write_stage].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i];
|
||||
});
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
const auto ds_tensor = generate_tuple(
|
||||
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie(
|
||||
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx],
|
||||
{step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -20,6 +20,7 @@ template <typename ADataType_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename CDElementwise_,
|
||||
index_t kBlockSize_,
|
||||
index_t kM_,
|
||||
index_t kN_,
|
||||
index_t MWave_,
|
||||
@@ -44,7 +45,7 @@ struct CShuffleEpilogueProblem
|
||||
using DsLayout = remove_cvref_t<DsLayout_>;
|
||||
using ELayout = remove_cvref_t<ELayout_>;
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
static constexpr index_t MWave = MWave_;
|
||||
@@ -228,8 +229,8 @@ struct CShuffleEpilogue
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
|
||||
@@ -376,9 +377,7 @@ struct CShuffleEpilogue
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem,
|
||||
const ScaleM& scale_m = {},
|
||||
const ScaleN& scale_n = {})
|
||||
void* p_smem)
|
||||
{
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
|
||||
@@ -399,9 +398,10 @@ struct CShuffleEpilogue
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
// using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
// sequence<0, 1>,
|
||||
// sequence<MPerIterationShuffle,
|
||||
// NPerIterationShuffle>>;
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
@@ -414,8 +414,7 @@ struct CShuffleEpilogue
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
constexpr auto dram_tile_distribution =
|
||||
TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
@@ -423,39 +422,27 @@ struct CShuffleEpilogue
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr bool has_scales =
|
||||
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
|
||||
auto scale_m_window = [&]() {
|
||||
if constexpr(has_scales)
|
||||
{
|
||||
return make_tile_window(scale_m, lds_tile.get_tile_distribution());
|
||||
}
|
||||
else
|
||||
{
|
||||
return EmptyScale{};
|
||||
}
|
||||
}();
|
||||
auto scale_n_window = [&]() {
|
||||
if constexpr(has_scales)
|
||||
{
|
||||
return make_tile_window(scale_n, lds_tile.get_tile_distribution());
|
||||
}
|
||||
else
|
||||
{
|
||||
return EmptyScale{};
|
||||
}
|
||||
}();
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
block_sync_lds();
|
||||
slice_acc_tile<iAccess>(o_acc_tile, lds_tile);
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
|
||||
if constexpr(has_scales)
|
||||
{
|
||||
scale_tile<iAccess>(lds_tile, scale_m_window, scale_n_window);
|
||||
}
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
|
||||
cast_lds_tile(lds_tile, in_lds_window);
|
||||
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
block_sync_lds();
|
||||
|
||||
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
@@ -465,8 +452,8 @@ struct CShuffleEpilogue
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
|
||||
number<NumDTensor>{}));
|
||||
generate_tie(
|
||||
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
|
||||
@@ -495,16 +482,16 @@ struct CShuffleEpilogue
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
typename ScaleMWindow,
|
||||
typename ScaleNWindow,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem,
|
||||
ScaleM scale_m,
|
||||
ScaleN scale_n)
|
||||
ScaleMWindow scale_m_window,
|
||||
ScaleNWindow scale_n_window)
|
||||
{
|
||||
constexpr int kM0 = MWave;
|
||||
constexpr int kM2 = 4;
|
||||
@@ -522,9 +509,43 @@ struct CShuffleEpilogue
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>;
|
||||
static_assert(GetVectorSizeC() % kN2 == 0);
|
||||
|
||||
constexpr auto dram_tile_distribution =
|
||||
make_static_tile_distribution(IntrThreadShuffleEncode{});
|
||||
|
||||
constexpr int DynamicTileOffsetFlag = 0;
|
||||
|
||||
auto permute_scale_n_view_1 = transform_tensor_view(
|
||||
scale_n_window.get_bottom_tensor_view(),
|
||||
make_tuple(make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
|
||||
make_unmerge_transform(make_tuple(number<DynamicTileOffsetFlag>{},
|
||||
number<NWave>{},
|
||||
number<NPerXdl>{},
|
||||
number<NRepeat>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2, 3, 4>{}));
|
||||
auto permute_scale_n_view = transform_tensor_view(
|
||||
permute_scale_n_view_1,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
|
||||
make_merge_transform_v3_division_mod(make_tuple(number<DynamicTileOffsetFlag>{},
|
||||
number<NRepeat>{},
|
||||
number<NWave>{},
|
||||
number<NPerXdl>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 4, 2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
auto scale_m_window_with_dist = make_tile_window(
|
||||
scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution());
|
||||
auto scale_n_window_with_dist = make_tile_window(permute_scale_n_view,
|
||||
scale_n_window.get_window_lengths(),
|
||||
scale_n_window.get_window_origin(),
|
||||
o_acc_tile.get_tile_distribution());
|
||||
|
||||
auto scale_m_buffer = load_tile(scale_m_window_with_dist);
|
||||
auto scale_n_buffer = load_tile(scale_n_window_with_dist);
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
@@ -542,56 +563,39 @@ struct CShuffleEpilogue
|
||||
make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
|
||||
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||
const index_t iMLane = get_lane_id() / NPerXdl;
|
||||
const index_t iNLane = get_lane_id() % NPerXdl;
|
||||
|
||||
float vec_scale_A[kM2 * MRepeat];
|
||||
float vec_scale_B[NRepeat];
|
||||
|
||||
_Pragma("unroll") for(int i = 0; i < NRepeat; ++i)
|
||||
{
|
||||
vec_scale_B[i] = scale_n[i + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl];
|
||||
}
|
||||
_Pragma("unroll") for(int i = 0; i < MRepeat; ++i)
|
||||
{
|
||||
vec_scale_A[i * kM2 + 0] =
|
||||
scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 1] =
|
||||
scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 2] =
|
||||
scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 3] =
|
||||
scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
}
|
||||
constexpr int NumAccPerEpiTile = NRepeat * c_warp_y_lengths.product();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
shuffle_acc[mIter].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 0] *= vec_scale_B[n_idx];
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 1] *= vec_scale_B[n_idx];
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 2] *= vec_scale_B[n_idx];
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 3] *= vec_scale_B[n_idx];
|
||||
});
|
||||
auto epi_scale_n = scale_n_buffer.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
static_for<0, NumAccPerEpiTile, 1>{}(
|
||||
[&](auto i) { shuffle_acc[mIter].get_thread_buffer()[i] *= epi_scale_n[i]; });
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
auto epi_scale_m = scale_m_buffer.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
// transpose <kM2 x NRepeat> thread matrix
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 0 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0] *
|
||||
vec_scale_A[mIter * kM2 + 0];
|
||||
epi_scale_m[n_idx * c_warp_y_lengths.product() + 0];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 1 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1] *
|
||||
vec_scale_A[mIter * kM2 + 1];
|
||||
epi_scale_m[n_idx * c_warp_y_lengths.product() + 1];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 2 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2] *
|
||||
vec_scale_A[mIter * kM2 + 2];
|
||||
epi_scale_m[n_idx * c_warp_y_lengths.product() + 2];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 3 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3] *
|
||||
vec_scale_A[mIter * kM2 + 3];
|
||||
epi_scale_m[n_idx * c_warp_y_lengths.product() + 3];
|
||||
});
|
||||
|
||||
c_out_tensor = cast_tile<ODataType>(c_out_tensor_fp32);
|
||||
@@ -615,16 +619,16 @@ struct CShuffleEpilogue
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
typename ScaleMWindow,
|
||||
typename ScaleNWindow,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<!EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem,
|
||||
ScaleM scale_m,
|
||||
ScaleN scale_n)
|
||||
ScaleMWindow scale_m_window,
|
||||
ScaleNWindow scale_n_window)
|
||||
{
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
|
||||
@@ -646,21 +650,18 @@ struct CShuffleEpilogue
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
|
||||
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPattern2D<kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
tile_distribution_encoding_pattern_2d<kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
@@ -673,63 +674,32 @@ struct CShuffleEpilogue
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr int kM2 = 4; // Val
|
||||
constexpr int kM1 = (64 / NPerXdl); // Thr
|
||||
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
|
||||
auto scale_m_window_with_dist = make_tile_window(
|
||||
scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution());
|
||||
auto scale_n_window_with_dist = make_tile_window(
|
||||
scale_n_window, scale_n_window.get_window_origin(), o_acc_tile.get_tile_distribution());
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||
const index_t iMLane = get_lane_id() / NPerXdl;
|
||||
const index_t iNLane = get_lane_id() % NPerXdl;
|
||||
auto scale_m_buffer = load_tile(scale_m_window_with_dist);
|
||||
auto scale_n_buffer = load_tile(scale_n_window_with_dist);
|
||||
|
||||
float vec_scale_A[kM0 * kM2 * MRepeat];
|
||||
float vec_scale_B[NRepeat];
|
||||
constexpr int NumAccPerEpiTile =
|
||||
NumMXdlPerWavePerShuffle * NumNXdlPerWavePerShuffle * c_warp_y_lengths.product();
|
||||
auto epi_tile_idx_slice =
|
||||
[&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) {
|
||||
return acc_tile_like_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
|
||||
epi_n_idx * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
};
|
||||
|
||||
_Pragma("unroll") for(int i = 0; i < NRepeat; ++i)
|
||||
{
|
||||
vec_scale_B[i] = scale_n[i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
}
|
||||
_Pragma("unroll") for(int i = 0; i < MRepeat; ++i)
|
||||
{
|
||||
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||
{
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 0] =
|
||||
scale_m[0 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 1] =
|
||||
scale_m[1 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 2] =
|
||||
scale_m[2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 3] =
|
||||
scale_m[3 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
}
|
||||
}
|
||||
lds_tile[0].get_thread_buffer() = epi_tile_idx_slice(o_acc_tile, number<0>{}, number<0>{});
|
||||
|
||||
lds_tile[0].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||
{
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 0] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 1] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 2] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 3] * vec_scale_B[n_xdl];
|
||||
}
|
||||
});
|
||||
});
|
||||
auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, number<0>{}, number<0>{});
|
||||
auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, number<0>{}, number<0>{});
|
||||
static_for<0, NumAccPerEpiTile, 1>{}(
|
||||
[&](auto i) { lds_tile[0].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i]; });
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
constexpr int read_stage = iAccess % 2;
|
||||
@@ -747,40 +717,14 @@ struct CShuffleEpilogue
|
||||
|
||||
if constexpr(iAccess < num_access - 1)
|
||||
{
|
||||
lds_tile[write_stage].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter * NumMXdlPerWavePerShuffle,
|
||||
nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||
{
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 0] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 1] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 2] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 3] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
}
|
||||
});
|
||||
lds_tile[write_stage].get_thread_buffer() =
|
||||
epi_tile_idx_slice(o_acc_tile, mIter, nIter);
|
||||
|
||||
epi_scale_m = epi_tile_idx_slice(scale_m_buffer, mIter, nIter);
|
||||
epi_scale_n = epi_tile_idx_slice(scale_n_buffer, mIter, nIter);
|
||||
|
||||
static_for<0, NumAccPerEpiTile, 1>{}([&](auto i) {
|
||||
lds_tile[write_stage].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i];
|
||||
});
|
||||
}
|
||||
|
||||
@@ -793,8 +737,8 @@ struct CShuffleEpilogue
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
|
||||
number<NumDTensor>{}));
|
||||
generate_tie(
|
||||
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using FlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
|
||||
using BlockGemmShape =
|
||||
remove_cvref_t<typename MXFlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
|
||||
remove_cvref_t<typename MXFlatmmPipeline_::BlockGemmShape>; // TileFlatmmShape
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
|
||||
@@ -36,17 +36,17 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<typename GemmPipeline::BlockGemm>;
|
||||
using MThreadPerXdl = BlockGemm::WarpGemm::kM;
|
||||
using NThreadPerXdl = BlockGemm::WarpGemm::kN;
|
||||
using KThreadPerXdl = get_warp_size() / MThreadPerXdl;
|
||||
using BlockGemm = remove_cvref_t<typename MXFlatmmPipeline_::BlockGemm>;
|
||||
static constexpr int MThreadPerXdl = BlockGemm::WarpGemm::kM;
|
||||
static constexpr int NThreadPerXdl = BlockGemm::WarpGemm::kN;
|
||||
static constexpr int KThreadPerXdl = 64 / MThreadPerXdl;
|
||||
|
||||
static constexpr int APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr int BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
static constexpr int MXdlPack = remove_cvref_t<typename FlatmmPipeline::MXdlPack>;
|
||||
static constexpr int NXdlPack = remove_cvref_t<typename FlatmmPipeline::NXdlPack>;
|
||||
static constexpr int KXdlPack = remove_cvref_t<typename FlatmmPipeline::KXdlPack>;
|
||||
static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
|
||||
static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
|
||||
static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
@@ -55,6 +55,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
static constexpr auto I4 = number<4>();
|
||||
static constexpr auto I5 = number<5>();
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
@@ -76,7 +77,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
constexpr int block_size = F16xMXF4FlatmmKernel::BlockSize().x;
|
||||
constexpr int block_size = FlatmmPipeline::BlockSize().x;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU = 0;
|
||||
|
||||
@@ -86,7 +87,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry2<block_size,
|
||||
F16xMXF4FlatmmKernel,
|
||||
FlatmmPipeline,
|
||||
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
@@ -201,7 +202,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
auto scale_a = kargs.scale_m_ptr;
|
||||
auto scale_b = kargs.scale_n_ptr;
|
||||
|
||||
static constexpr int BlockScaleSize = decltype(scale_n)::GranularityK;
|
||||
static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;
|
||||
|
||||
// A scale tensor view
|
||||
const auto& scale_a_tensor_view = [&]() {
|
||||
@@ -215,7 +216,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
scale_a_naive_desc,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(Padded_Scale_M / (MXdlPack * MThreadPerXdl), MThreadPerXdl)),
|
||||
make_tuple(kargs.M / (MXdlPack * MThreadPerXdl), MThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(
|
||||
kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl), KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
@@ -397,8 +398,8 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
// 32>{}),
|
||||
// {i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
|
||||
auto scale_a = kargs.scale_m_ptr;
|
||||
static constexpr int BlockScaleSize = decltype(scale_n)::GranularityK;
|
||||
// auto scale_a = kargs.scale_m_ptr;
|
||||
static constexpr int BlockScaleSize = 32;
|
||||
|
||||
auto scale_a_block_window = make_tile_window(
|
||||
views.at(I4),
|
||||
|
||||
@@ -158,7 +158,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// static constexpr int ScaleKPerWarp = KIterPerWarp / XDL_PerScaleK;
|
||||
// static constexpr int ScaleNPerWarp = NIterPerWarp / XDL_PerScaleN;
|
||||
|
||||
static constexpr int MXFP4K_PerScaleK = MXFP4KPerWarp / ScaleKPerWarp;
|
||||
// static constexpr int MXFP4K_PerScaleK = MXFP4KPerWarp / ScaleKPerWarp;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
@@ -178,7 +178,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static constexpr index_t Aload_num_perK = dswrite_num_perK;
|
||||
static constexpr index_t Aload_rep = dswrite_rep;
|
||||
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
|
||||
static constexpr index_t ScaleBload_K1 = ContinuousScaleNPerThread * ContinuousScaleKPerThread;
|
||||
static constexpr index_t ScaleBload_K1 = NXdlPack * KXdlPack; // fixed for fp4
|
||||
static constexpr index_t ScaleBload_num =
|
||||
kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 /
|
||||
WaveSize; // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize
|
||||
@@ -631,7 +631,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
|
||||
using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
|
||||
// use v4i32 as the data type between basicblock to avoid unpack and repack operation.
|
||||
using V4UInt_Buffer = thread_buffer<uint32_t, XDL_PerWeightK>;
|
||||
using V4UInt_Buffer = thread_buffer<uint32_t, 4>;
|
||||
union UnionB
|
||||
{
|
||||
V4UInt_Buffer u = 0;
|
||||
@@ -718,24 +718,24 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// prefetch Scale A and Scale B
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter)(kIter) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter)(kIter),
|
||||
{mIter * MWarp * WG::kM, kIter * (64 / WG::kM)});
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_ping(mIter)(kIter) =
|
||||
load_tile(scale_a_dram_windows(mIter)(kIter));
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter)(kIter) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter)(kIter),
|
||||
{nIter * NWarp * WG::kN, kIter * (64 / WG::kN)});
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_ping(nIter)(kIter) =
|
||||
load_tile(scale_b_dram_windows(nIter)(kIter));
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
@@ -793,23 +793,23 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// prefetch Scale A and Scale B (2i+1)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter)(kIter) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter)(kIter),
|
||||
{mIter * MWarp * WG::kM, kIter * (64 / WG::kM)});
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_pong(mIter)(kIter) =
|
||||
load_tile(scale_a_dram_windows(mIter)(kIter));
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter)(kIter) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter)(kIter),
|
||||
{nIter * NWarp * WG::kN, kIter * (64 / WG::kN)});
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_pong(nIter)(kIter) =
|
||||
load_tile(scale_b_dram_windows(nIter)(kIter));
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -825,7 +825,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPacke, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
@@ -850,7 +850,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -888,7 +888,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShap::flatKPerBlock});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
@@ -905,8 +905,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
move_tile_window(
|
||||
@@ -922,23 +922,23 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// prefetch Scale A and Scale B (2i+2)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter)(kIter) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter)(kIter),
|
||||
{mIter * MWarp * WG::kM, kIter * (64 / WG::kM)});
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_ping(mIter)(kIter) =
|
||||
load_tile(scale_a_dram_windows(mIter)(kIter));
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter)(kIter) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter)(kIter),
|
||||
{nIter * NWarp * WG::kN, kIter * (64 / WG::kN)});
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_ping(nIter)(kIter) =
|
||||
load_tile(scale_b_dram_windows(nIter)(kIter));
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -979,7 +979,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_pong(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_pong(nIter_pack)(kIter_pack), // scale A
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -1017,7 +1017,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShap::flatKPerBlock});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
@@ -1036,15 +1036,14 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// prefetch B(loopK)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
|
||||
packed_n_rank,
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
@@ -1055,23 +1054,23 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// prefetch Scale A and Scale B (2i+1)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter)(kIter) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter)(kIter),
|
||||
{mIter * MWarp * WG::kM, kIter * (64 / WG::kM)});
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_pong(mIter)(kIter) =
|
||||
load_tile(scale_a_dram_windows(mIter)(kIter));
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter)(kIter) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter)(kIter),
|
||||
{nIter * NWarp * WG::kN, kIter * (64 / WG::kN)});
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_pong(nIter)(kIter) =
|
||||
load_tile(scale_b_dram_windows(nIter)(kIter));
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1082,7 +1081,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPacke, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
@@ -1107,7 +1106,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -1181,7 +1180,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_pong(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_pong(nIter_pack)(kIter_pack), // scale A
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -1224,7 +1223,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPacke, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
@@ -1249,7 +1248,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -1297,8 +1296,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
typename ScaleBDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const ScaleADramblockWindowTmp& scale_a_flat_window_tmp,
|
||||
const ScaleBDramblockWindowTmp& scale_b_flat_window_tmp,
|
||||
const ScaleADramBlockWindowTmp& scale_a_flat_window_tmp,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_flat_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
|
||||
@@ -13,7 +13,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
// static constexpr index_t KBPerLoad = 32;
|
||||
static constexpr index_t KBPerLoad = 32;
|
||||
// static constexpr index_t N_Pack = 2; // it's fixed for fp4
|
||||
// static constexpr index_t K_Pack = 2; // it's fixed for fp4
|
||||
|
||||
@@ -35,10 +35,10 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>() * APackedSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>() * APackedSize;
|
||||
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
@@ -117,9 +117,8 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
|
||||
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
|
||||
constexpr index_t MXdlPack = Problm::MXdlPack;
|
||||
constexpr int NWaves = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr int M0 = TileShape::WarpTile::at(I0);
|
||||
constexpr int NWaves = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr int M0 = TileShape::WarpTile::at(I0);
|
||||
|
||||
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4
|
||||
|
||||
@@ -209,24 +208,24 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t kMPerBlock = tileShape::BlockTile::at(I0);
|
||||
constexpr index_t kMPerBlock = TileShape::BlockTile::at(I0);
|
||||
|
||||
constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
|
||||
constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
|
||||
|
||||
static_assert(num_warps == M_Warps * N_Warps, "Block warps do not match block size");
|
||||
static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
|
||||
|
||||
constexpr index_t M_Lanes = TileShape::WarpTile::at(I0);
|
||||
constexpr index_t K_Lanes = 64 / M_Lanes;
|
||||
|
||||
// Y dimension (M) decomposition
|
||||
static constexpr index_t Y2 = M_Lanes;
|
||||
static constexpr index_t Y1 = M_Warps;
|
||||
static constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2);
|
||||
constexpr index_t Y2 = M_Lanes;
|
||||
constexpr index_t Y1 = M_Warps;
|
||||
constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2);
|
||||
|
||||
// X dimension (K) decomposition
|
||||
static constexpr index_t X0 = K_Lanes;
|
||||
static constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
|
||||
constexpr index_t X0 = K_Lanes;
|
||||
constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<N_Warps>, // repeat N_warps
|
||||
@@ -246,24 +245,24 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t kNPerBlock = tileShape::BlockTile::at(I1);
|
||||
constexpr index_t kNPerBlock = TileShape::BlockTile::at(I1);
|
||||
|
||||
constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
|
||||
constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
|
||||
|
||||
static_assert(num_warps == M_Warps * N_Warps, "Block warps do not match block size");
|
||||
static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
|
||||
|
||||
constexpr index_t N_Lanes = TileShape::WarpTile::at(I1);
|
||||
constexpr index_t K_Lanes = 64 / N_Lanes;
|
||||
|
||||
// Y dimension (M) decomposition
|
||||
static constexpr index_t Y2 = N_Lanes;
|
||||
static constexpr index_t Y1 = N_Warps;
|
||||
static constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2);
|
||||
constexpr index_t Y2 = N_Lanes;
|
||||
constexpr index_t Y1 = N_Warps;
|
||||
constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2);
|
||||
|
||||
// X dimension (K) decomposition
|
||||
static constexpr index_t X0 = K_Lanes;
|
||||
static constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
|
||||
constexpr index_t X0 = K_Lanes;
|
||||
constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<M_Warps>, // ?
|
||||
@@ -284,19 +283,19 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{});
|
||||
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0);
|
||||
constexpr index_t N_Lane = TileShape::WarpTile::at(I0);
|
||||
constexpr index_t M_Lane = TileShape::WarpTile::at(I0);
|
||||
|
||||
constexpr index_t MWavePerBlk = M_Warp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distributed_encoding<sequence<>, // ?
|
||||
tuple<sequence<MWavePerBlk, M_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
tile_distribution_encoding<sequence<>, // ?
|
||||
tuple<sequence<MWavePerBlk, M_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -314,14 +313,14 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t NWavePerBlk = N_Warp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distributed_encoding<sequence<>, // ?
|
||||
tuple<sequence<NWavePerBlk, N_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
tile_distribution_encoding<sequence<>, // ?
|
||||
tuple<sequence<NWavePerBlk, N_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -14,8 +14,4 @@ if(CK_USE_OCP_FP8 OR CK_USE_FNUZ_FP8)
|
||||
if(CK_USE_OCP_FP8)
|
||||
target_compile_options(test_ck_tile_fp8 PRIVATE -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp)
|
||||
add_gtest_executable(test_ck_tile_mx_scale test_mx_scale.cpp)
|
||||
endif()
|
||||
Reference in New Issue
Block a user