Merge branch 'develop' into users/yiding12/fmha-bwd-workspace

This commit is contained in:
Yi DING
2026-04-22 15:17:57 +08:00
committed by GitHub
537 changed files with 26614 additions and 21905 deletions

View File

@@ -19,19 +19,29 @@
#include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp"
#include "ck_tile/core/arch/mma/scale/mfma/selector.hpp"
#include "ck_tile/core/arch/mma/scale/scale.hpp"
#include "ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"
#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp"
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp"
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp"

View File

@@ -2166,27 +2166,11 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
}
else if constexpr(N == 8)
{
#if 0
thread_buffer<fp16_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(fp16_t),
static_cast<index_t>(coherence));
#else
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<fp32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
#endif
}
}
else if constexpr(std::is_same<T, bf16_t>::value) // bf16

View File

@@ -1992,27 +1992,11 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
}
else if constexpr(N == 8)
{
#if 0
thread_buffer<fp16_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(fp16_t),
static_cast<index_t>(coherence));
#else
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<fp32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
#endif
}
}
else if constexpr(std::is_same<T, bf16_t>::value) // bf16

View File

@@ -4,6 +4,8 @@
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma_traits.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
@@ -87,7 +89,7 @@ namespace ck_tile::core::arch::mma {
*
* (logical correctness). Applies to scale MFMA fp8, which due to the index matrix layout does not
* allow arbitrary K perms to simplify layouts. This means the layout can only properly be described
* with a Num Access value of at least 2.
* with a Num Access value which is a multiple of 2.
*
* (load / store manipulation). It seems like the load and store tile functions end up looking for
* the size of the smallest unmerged K dimension (K0) to determine how many elements should be
@@ -102,13 +104,16 @@ namespace ck_tile::core::arch::mma {
*
* -- CMPerLane --
* The number of M dim elements in each lane. In terms of unmerge sizes, it's equal to M0 * M2, i.e
* the product of the sizes of the outermost and innermost dimensions after a double M unmerge.
* the product of the sizes of the outermost and innermost dimensions after a double M unmerge. This
* does not count a potential increased M dimension size from block hiding. In this case, we have M
* = kCMBlock * M2 * M1 * M0 instead.
*
* -- CNumAccess --
* Same as A / B NumAccess but for the M dim (so M2), but the mid-level code doesn't care about this
* and will not try to request a specific value. Absolutely needed for logical correctness of
* register mappings since we can not perform arbitrary M permutations without messing up the A
* layout.
* layout. This does not count a potential increased M dimension size from block hiding. In this
* case, we have M = kCMBlock * M2 * M1 * M0 instead.
*/
/**
@@ -144,7 +149,7 @@ struct amdgcn_mma_base
using CDataType = CDataType_;
// Fragment (MmaTile) sizes, check description above.
static constexpr index_t kM = FragM; // M = M2 * M1 * M0
static constexpr index_t kM = FragM; // M = M2 * M1 * M0 (* kCMBlocks when block-hiding)
static constexpr index_t kN = FragN;
static constexpr index_t kK = FragK; // K = K2 * K1 * K0
@@ -157,15 +162,37 @@ struct amdgcn_mma_base
static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0
static constexpr index_t kCMNumAccess = kCMNumAccess_; // M2
// K-dimension compression ratio for A matrix, always 2 for sparse intrinsics.
static constexpr index_t kCompressionRatio = (OpFamily == MmaOpFamily::SPARSE) ? 2 : 1;
// Layout checks
static_assert(kK % kABKPerLane == 0);
static_assert(kABKPerLane % kAKNumAccess == 0);
static_assert(kABKPerLane % kBKNumAccess == 0);
static_assert(kCMPerLane % kCMNumAccess == 0);
// Register types (derived)
static constexpr index_t WaveSize = WaveSize_;
static_assert((kM * kK * kARepeat) % WaveSize == 0);
static_assert((kM * kK * kARepeat) % (WaveSize * kCompressionRatio) == 0);
static_assert((kN * kK * kBRepeat) % WaveSize == 0);
static_assert((kM * kN) % WaveSize == 0);
using AVecType = ext_vector_t<ADataType, kM * kK * kARepeat / WaveSize>;
using AVecType = ext_vector_t<ADataType, kM * kK * kARepeat / WaveSize / kCompressionRatio>;
using BVecType = ext_vector_t<BDataType, kN * kK * kBRepeat / WaveSize>;
using CVecType = ext_vector_t<CDataType, kM * kN / WaveSize>;
// Block-hiding / repeat related traits (derived)
static_assert(kARepeat == kBRepeat || !std::is_same_v<OpType, WmmaOp>);
static_assert(kARepeat == 1 || kBRepeat == 1 || !std::is_same_v<OpType, MfmaOp>);
static constexpr index_t kCMBlocks = std::is_same_v<OpType, MfmaOp> ? kBRepeat : 1;
static constexpr index_t kCNBlocks = std::is_same_v<OpType, MfmaOp> ? kARepeat : 1;
static_assert(kM % (kCMBlocks * kCMPerLane) == 0);
static_assert(kN % kCNBlocks == 0);
// For the C matrix, the block dimension B is either put in the Vector dimension or the Lane
// dimension. We can tell which by checking if we get the right Vector size.
static constexpr bool CBlockDimInVecDim =
kCMBlocks * kCNBlocks * kCMPerLane == vector_traits<CVecType>::vector_size;
};
/**
@@ -177,15 +204,30 @@ struct Unsupported;
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>
/**
* @concept HasExecSignature
* @brief Helper concept for exec signature check.
*/
template <typename MmaOp, typename... ExecArgs>
concept HasExecSignature = requires {
{
MmaOp::exec(typename MmaOp::AVecType{},
typename MmaOp::BVecType{},
typename MmaOp::CVecType{},
std::declval<ExecArgs>()...)
} -> std::convertible_to<typename MmaOp::CVecType>;
};
/**
* @concept MmaOpI
* @brief Expresses the meta-data interface required for each MmaOp policy.
*/
// TODO: Make sure this actually matches amdgcn_mma.
template <typename MmaOp>
concept MmaOpI = requires(MmaOp op) {
// Requires an op context
typename MmaOp::OpType;
typename MmaOp::OpFamily;
{ MmaOp::OpFamily } -> std::convertible_to<MmaOpFamily>;
// Captures types for inputs / outputs to mma function
typename MmaOp::ADataType;
@@ -194,7 +236,6 @@ concept MmaOpI = requires(MmaOp op) {
typename MmaOp::AVecType;
typename MmaOp::BVecType;
typename MmaOp::CVecType;
// Captures CK-specific layout properties
{ MmaOp::kABKPerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kAKNumAccess } -> std::convertible_to<unsigned int>;
@@ -203,13 +244,8 @@ concept MmaOpI = requires(MmaOp op) {
{ MmaOp::kBRepeat } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMPerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMNumAccess } -> std::convertible_to<unsigned int>;
// Static exec function
{
MmaOp::exec(
typename MmaOp::AVecType{}, typename MmaOp::BVecType{}, typename MmaOp::CVecType{})
} -> std::convertible_to<typename MmaOp::CVecType>;
};
{ MmaOp::kCompressionRatio } -> std::convertible_to<unsigned int>;
} && (HasExecSignature<MmaOp> || HasExecSignature<MmaOp, int> || HasExecSignature<MmaOp, int, int>);
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
@@ -248,7 +284,7 @@ struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1
// clang-format on
{
// This is a default pass-through implementation that doesn't do anything practical.
CK_TILE_DEVICE static CVecType const&
CK_TILE_DEVICE static auto
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
{
// Prints once across all thread blocks and threads.
@@ -267,6 +303,8 @@ struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1
#pragma clang diagnostic pop
// Include the implementations
#include "wmma/wmma.hpp"
#include "wmma/wmma.hpp" // should be included before the below headers
#include "mfma/mfma.hpp"
#include "scale/scale.hpp"
#include "sparse/sparse.hpp"

View File

@@ -51,6 +51,82 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
}
};
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 64u, 32u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 64u, 32u, 4u, 64u, 4, 1, 1, 1, 2, 16, 4, MfmaOp, MmaOpFamily::DENSE>
// clang-format on
{
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{
return {__builtin_amdgcn_mfma_f32_32x32x4f16(aVec,
bVec,
cVec,
static_cast<int>(CtrlFlags::Cbsz),
static_cast<int>(CtrlFlags::Abid),
static_cast<int>(CtrlFlags::Blgp))};
}
};
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 64u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 32u, 64u, 4u, 64u, 4, 1, 2, 1, 1, 16, 4, MfmaOp, MmaOpFamily::DENSE>
// clang-format on
{
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{
return {__builtin_amdgcn_mfma_f32_32x32x4f16(aVec,
bVec,
cVec,
static_cast<int>(CtrlFlags::Cbsz),
static_cast<int>(CtrlFlags::Abid),
static_cast<int>(CtrlFlags::Blgp))};
}
};
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 64u, 4u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 64u, 4u, 4u, 64u, 4, 1, 1, 1, 16, 4, 1, MfmaOp, MmaOpFamily::DENSE>
// clang-format on
{
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{
return {__builtin_amdgcn_mfma_f32_4x4x4f16(aVec,
bVec,
cVec,
static_cast<int>(CtrlFlags::Cbsz),
static_cast<int>(CtrlFlags::Abid),
static_cast<int>(CtrlFlags::Blgp))};
}
};
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 4u, 64u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 4u, 64u, 4u, 64u, 4, 1, 16, 1, 1, 4, 1, MfmaOp, MmaOpFamily::DENSE>
// clang-format on
{
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
{
return {__builtin_amdgcn_mfma_f32_4x4x4f16(aVec,
bVec,
cVec,
static_cast<int>(CtrlFlags::Cbsz),
static_cast<int>(CtrlFlags::Abid),
static_cast<int>(CtrlFlags::Blgp))};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for MFMA on GFX950 targets

View File

@@ -28,9 +28,11 @@ struct MmaDefaultTransformsGfx9
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires
template <typename MmaOp, typename CompilerTarget>
struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
enable_if_target_family_gfx9_t<CompilerTarget>>
struct MmaTransformsDefaultSelector<
MmaOp,
CompilerTarget,
enable_if_all<enable_if_target_family_gfx9_t<CompilerTarget>,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::DENSE>>>
{
using SelectedTransforms = MmaDefaultTransformsGfx9;
};

View File

@@ -1,230 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "amdgcn_mma.hpp"
#include "mma_selector.hpp"
#include "mma_transforms.hpp"
#include "mfma/mfma.hpp"
#include "wmma/wmma.hpp"
namespace ck_tile::core::arch::mma {
/*! @enum MmaAccumPolicy
* @brief Accumulation order for Mma decomposition
*/
enum struct MmaAccumPolicy
{
// Decomposition and accumulation in row-major fragment order
ROW_MAJOR,
// Decomposition and accumulation in col-major fragment order
COL_MAJOR
};
/**
* @class Mma
* @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation
* (e.g., mfma or wmma), this class performs fragment-wise (MmaTile) decomposition to
* matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and
* accumulates results into output WaveTile (C: WaveTileM x WaveTileN).
* @tparam ADataType Data type of input WaveTile A
* @tparam BDataType Data type of input WaveTile B
* @tparam CDataType Data type of input/output WaveTile C (accumulator)
* @tparam WaveTileM Mma WaveTile M dimension
* @tparam WaveTileN Mma WaveTile K dimension
* @tparam WaveTileK Mma WaveTile M dimension
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
* @tparam CompilerTarget The compiler target
* @tparam MmaOp Backend wrapper class that will perform the mma op (e.g., mfma or wmma)
* @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles
* @par This is an example of an Mma decomposition driver class that can be used in a wave-tile
* context. Given a WaveTile size, we can decompose the WaveTile into smaller mma op fragments
* that are natively supported by the hardware (e.g., mfma or wmma). The class also supports
* applying transforms to the input/output frags as needed (e.g., layout conversions, data type
* conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the
* output WaveTile. This is a powerful example of how to build a flexible and reusable mma driver
* that can adapt to different hardware capabilities and requirements.
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaOpFamily OpFamily,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
typename MmaOp =
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
OpFamily>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
struct WaveWiseMma
{
using FragWiseMmaOp = MmaOp;
// Fragment dimensions
constexpr static uint32_t FragM = MmaOp::kM;
constexpr static uint32_t FragN = MmaOp::kN;
constexpr static uint32_t FragK = MmaOp::kK;
// Fragment counts for decomposition
constexpr static uint32_t FragsM = WaveTileM / FragM;
constexpr static uint32_t FragsN = WaveTileN / FragN;
constexpr static uint32_t FragsK = WaveTileK / FragK;
constexpr static uint32_t FragsC = FragsM * FragsN;
// Vector types for packed registers in each fragment
using AVecType = typename MmaOp::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
// Buffer types for WaveTiles
using ABufferType = AVecType[FragsM][FragsK];
using BBufferType = BVecType[FragsN][FragsK];
using CBufferType = CVecType[FragsM][FragsN];
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
// Sanity checks
static_assert(WaveTileM >= FragM, "WaveTileM must be larger than FragM");
static_assert(WaveTileN >= FragN, "WaveTileN must be larger than FragN");
static_assert(WaveTileK >= FragK, "WaveTileK must be larger than FragK");
static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM");
static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN");
static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK");
private:
template <typename DstT, typename SrcT>
CK_TILE_DEVICE static auto formatBuffer(SrcT const& inputBuffer)
{
// TODO: Implement formatting logic as needed.
// This is intended to convert input WaveTiles to the native vector types
// required by the FragWiseMma operation for iteration
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
return reinterpret_cast<DstT const&>(inputBuffer);
}
template <typename DstT, typename SrcT>
CK_TILE_DEVICE static auto formatBuffer(SrcT& inputBuffer)
{
// TODO: Implement formatting logic as needed.
// This is intended to convert input WaveTiles to the native vector types
// required by the FragWiseMma operation for iteration
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
return reinterpret_cast<DstT&>(inputBuffer);
}
/*! @brief Execute Mma in row-major accumulation order.
* @tparam VecTA The input WaveTile A vector type
* @tparam VecTB The input WaveTile B vector type
* @tparam VecTC The input/output WaveTile C vector type
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum)
{
// We implement an example wave-tile pipeline here.
// First, we apply the necessary transforms to the input fragments,
// then we convert the result into buffers of native vector formats
// that we can easily index. Native vector formats are necessary inputs
// to the given MmaOp exec function.
auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
// "Col-major" accumulation over the M-dimension fragments first.
// Pseudo code here, but we would basically iterate over the fragments in col-major order
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] =
FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
}
}
}
// Convert native vector results back to the output WaveTile format
// and then return after we apply the final output transform.
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
}
/*! @brief Execute Mma in row-major accumulation order.
* @tparam VecTA The input WaveTile A vector type
* @tparam VecTB The input WaveTile B vector type
* @tparam VecTC The input/output WaveTile C vector type
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum)
{
// We implement an example wave-tile pipeline here.
// First, we apply the necessary transforms to the input WaveTiles,
// then we convert the result into buffers of native vector formats
// that we can easily index. Native vector formats are necessary inputs
// to the given MmaOp exec function.
auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
// "Row-major" accumulation over the N-dimension fragments first.
// Pseudo code here, but we would basically iterate over the fragments in row-major order.
// We also have to ensure that the incoming vector WaveTiles are converted to native vector
// types before passing to the FragWiseMma exec function.
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] =
FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
}
}
}
// Convert native vector results back to the output WaveTile format
// and then return after we apply the final output transform.
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
}
public:
/*! @brief Forward to Mma operation with specified accumulation order.
* @tparam VecTA The input WaveTile A vector type
* @tparam VecTB The input WaveTile B vector type
* @tparam VecTC The input/output WaveTile C vector type
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
{
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
{
return exec_row_major(
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
}
else // if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
{
return exec_col_major(
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
}
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,343 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "amdgcn_mma.hpp"
#include "mma_selector.hpp"
#include "mma_traits.hpp"
#include "mma_transforms.hpp"
namespace ck_tile::core::arch::mma {
/*! @enum MmaPipelineOptionFlag
* @brief Individual option flags for configuring MmaPipeline behavior.
*/
enum struct MmaPipelineOptionFlag : unsigned
{
NONE = 0x0, ///< No flags set
ABSwap = 0x1, ///< Swap A and B inputs to transpose the C output
COMPRESS_A = 0x2, ///< Enable compressed (sparse) A matrix input
};
/**
* @struct MmaPipelineOptionFlags
* @brief Type-safe bitmask wrapper for combining @ref MmaPipelineOptionFlag values.
* @par Provides bitwise OR, AND, NOT, and equality operators for composing
* and querying pipeline option flags.
*/
struct MmaPipelineOptionFlags
{
using Type = std::underlying_type_t<MmaPipelineOptionFlag>;
explicit constexpr MmaPipelineOptionFlags() : mFlags(0) {}
explicit constexpr MmaPipelineOptionFlags(Type value) : mFlags(value) {}
constexpr MmaPipelineOptionFlags(MmaPipelineOptionFlag singleFlag) : mFlags(toType(singleFlag))
{
}
constexpr MmaPipelineOptionFlags(const MmaPipelineOptionFlags& original)
: mFlags(original.mFlags)
{
}
constexpr MmaPipelineOptionFlags& operator|=(MmaPipelineOptionFlag addValue)
{
mFlags |= toType(addValue);
return *this;
}
constexpr MmaPipelineOptionFlags operator|(MmaPipelineOptionFlag addValue) const
{
MmaPipelineOptionFlags result(*this);
result |= addValue;
return result;
}
constexpr MmaPipelineOptionFlags& operator&=(MmaPipelineOptionFlag maskValue)
{
mFlags &= toType(maskValue);
return *this;
}
constexpr MmaPipelineOptionFlags operator&(MmaPipelineOptionFlag maskValue) const
{
MmaPipelineOptionFlags result(*this);
result &= maskValue;
return result;
}
constexpr MmaPipelineOptionFlags operator~() const
{
MmaPipelineOptionFlags result(*this);
result.mFlags = ~result.mFlags;
return result;
}
constexpr bool testFlag(MmaPipelineOptionFlag flag) const
{
return (flag == MmaPipelineOptionFlag::NONE) ? mFlags == toType(flag) : *this & flag;
}
constexpr operator bool() const { return mFlags != toType(MmaPipelineOptionFlag::NONE); }
constexpr bool operator==(Type rhs) const { return mFlags == rhs; }
private:
Type mFlags;
static constexpr Type toType(MmaPipelineOptionFlag f) { return static_cast<Type>(f); }
};
constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOptionFlags& rhs)
{
return rhs == lhs;
}
/**
* @class MmaPipelineBase
* @brief CRTP base class that implements the common Mma pipeline logic shared by
* all concrete pipeline drivers (e.g., dense wave-wise, sparse, etc.).
*
* @tparam Flags_ Compile-time bitmask of @ref MmaPipelineOptionFlag controlling
* pipeline behavior (e.g., C transposition, A compression).
* @tparam Derived The concrete CRTP-derived pipeline class. Must expose:
* - Type aliases: @c InternalAVecT, @c InternalBVecT, @c InternalCVecT,
* @c CVecType, @c MmaOp
* - Transform aliases: @c ATransform, @c BTransform, @c CTransform,
* @c DTransform
* - A static @c execImpl(std::tuple<A,B,C>&) method.
*
* @par The pipeline performs the following steps in @c exec():
* 1. Apply pre-transforms and format input buffers (A, B, C).
* 2. Delegate to @c Derived::execImpl for the actual mma loop.
* 3. Apply post-transform and format the output buffer (D) back to the user type.
* When @c ABSwap is set, the A and B inputs are swapped before step 1.
*/
// TODO: c++20: use MmaPipelineOptionFlags directly
template <MmaPipelineOptionFlags::Type Flags_, typename Derived>
struct MmaPipelineBase
{
static constexpr auto Flags = MmaPipelineOptionFlags(Flags_);
private:
/**
* @brief Reconstruct a tuple with its first element passed through @c formatBuffer
* while preserving all remaining elements unchanged.
* @tparam DstT Target type for the formatted first element.
* @tparam SrcT Forwarding-reference type of the input tuple.
* @tparam Is Index pack for elements 1..N-1 of the tuple.
* @param inputTuple The source tuple whose first element will be formatted.
* @return A new tuple with the formatted first element and the remaining elements forwarded.
*/
template <typename DstT, typename SrcT, std::size_t... Is>
CK_TILE_DEVICE static auto formatBufferTupleImpl(SrcT&& inputTuple, std::index_sequence<Is...>)
{
auto&& first_elem = std::get<0>(std::forward<SrcT>(inputTuple));
using FirstElemResultType =
decltype(formatBuffer<DstT>(std::forward<decltype(first_elem)>(first_elem)));
using InputTupleType = ck_tile::remove_cvref_t<SrcT>;
return std::tuple<FirstElemResultType, std::tuple_element_t<Is + 1, InputTupleType>...>(
formatBuffer<DstT>(std::forward<decltype(first_elem)>(first_elem)),
std::get<Is + 1>(std::forward<SrcT>(inputTuple))...);
}
/**
* @brief Format (reinterpret-cast) a buffer to the hardware-native vector type @p DstT.
*
* Three cases are handled:
* - **Tuple**: recursively format the first element via @c formatBufferTupleImpl,
* preserving any metadata in the remaining tuple elements.
* - **Array / Pointer**: forwarded unchanged.
* - **Scalar / Vector**: reinterpret-cast to @p DstT (sizes must match).
*
* @tparam DstT The target hardware vector type.
* @tparam SrcT Forwarding-reference type of the input buffer.
* @param inputBuffer The buffer to format.
* @return A reference (or value) of type @p DstT corresponding to @p inputBuffer.
*/
template <typename DstT, typename SrcT>
CK_TILE_DEVICE static decltype(auto) formatBuffer(SrcT&& inputBuffer)
{
using DecayedSrcT = ck_tile::remove_cvref_t<SrcT>;
// If SrcT is a tuple, extract the first element (the vector) and format it
// while preserving all remaining elements (metadata)
if constexpr(is_std_tuple_v<DecayedSrcT>)
{
// Create index sequence for all remaining elements (skip first)
constexpr std::size_t tuple_size = std::tuple_size_v<DecayedSrcT>;
return formatBufferTupleImpl<DstT>(std::forward<SrcT>(inputBuffer),
std::make_index_sequence<tuple_size - 1>{});
}
else if constexpr(std::is_array_v<DecayedSrcT> || std::is_pointer_v<DecayedSrcT>)
{
return std::forward<SrcT>(inputBuffer);
}
else
{
static_assert(sizeof(DstT) == sizeof(DecayedSrcT), "Size mismatch in formatBuffer");
using QualifiedDstT =
std::conditional_t<std::is_const_v<DecayedSrcT>, DstT const, DstT>;
return reinterpret_cast<QualifiedDstT&>(inputBuffer);
}
}
protected:
/** @brief Query whether a specific @ref MmaPipelineOptionFlag is set. */
template <MmaPipelineOptionFlag Flag>
constexpr CK_TILE_DEVICE static bool hasFlag()
{
return Flags.testFlag(Flag);
}
/**
* @brief Apply a transform **then** format the result to @p DstT.
* Used for input operands (A, B, C) before the mma loop.
*/
template <typename DstT, typename Transform, typename... Args>
CK_TILE_DEVICE static auto preApplyTransform(Args&&... args)
{
return formatBuffer<DstT>(Transform::exec(std::forward<Args>(args)...));
}
/**
* @brief Format a buffer to @p DstT **then** apply a transform.
* Used for the output operand (D) after the mma loop.
*/
template <typename DstT, typename Transform, typename... Args>
CK_TILE_DEVICE static auto postApplyTransform(Args&&... args)
{
return Transform::exec(formatBuffer<DstT>(std::forward<Args>(args)...));
}
/**
* @brief Apply the per-operand pre-transforms and buffer formatting to A, B, and C.
* @return A @c std::tuple of the transformed (A, B, C, [scaleA, scaleB]) vectors ready for the
* mma loop.
*/
template <typename ATransformInputs,
typename BTransformInputs,
typename CTransformInputs,
typename... ExtraArgs>
CK_TILE_DEVICE static decltype(auto) applyTransformsToInputs(ATransformInputs&& a,
BTransformInputs&& b,
CTransformInputs&& accum,
ExtraArgs&&... extras)
{
using InternalAVecT = typename Derived::InternalAVecT;
using InternalBVecT = typename Derived::InternalBVecT;
using InternalCVecT = typename Derived::InternalCVecT;
using ATransform = typename Derived::ATransform;
using BTransform = typename Derived::BTransform;
using CTransform = typename Derived::CTransform;
return std::make_tuple(
preApplyTransform<InternalAVecT, ATransform>(std::forward<ATransformInputs>(a)),
preApplyTransform<InternalBVecT, BTransform>(std::forward<BTransformInputs>(b)),
preApplyTransform<InternalCVecT, CTransform>(std::forward<CTransformInputs>(accum)),
std::forward<ExtraArgs>(extras)...);
}
/**
* @brief Apply the post-transform and buffer formatting to the C (accumulator) output.
* @param c_result The accumulator to post-process.
* @return The final D output in the user-facing vector type.
*/
template <typename CTransformResult>
CK_TILE_DEVICE static auto applyTransformToOutput(CTransformResult&& c_result)
{
static_assert(!is_std_tuple_v<decltype(c_result)>,
"If CTransform returns more than the vector, update this function.");
using CVecT = typename Derived::CVecType;
using DTransform = typename Derived::DTransform;
return postApplyTransform<CVecT, DTransform>(c_result);
}
public:
/**
* @brief Entry point: execute the full Mma pipeline (transforms + mma loop + output).
* @tparam VecTA Type of the A WaveTile buffer.
* @tparam VecTB Type of the B WaveTile buffer.
* @tparam VecTC Type of the C (accumulator) WaveTile buffer.
* @param a Input WaveTile A.
* @param b Input WaveTile B.
* @param accum Input/output accumulator WaveTile C.
* @return The output WaveTile D after accumulation and post-transform.
*/
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
{
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
{
auto transformed_inputs = applyTransformsToInputs(
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTB>(b)
: std::forward<VecTA>(a),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTA>(a)
: std::forward<VecTB>(b),
std::forward<VecTC>(accum));
Derived::execImpl(transformed_inputs);
auto&& [a_result, b_result, c_result] = std::move(transformed_inputs);
return applyTransformToOutput(std::move(c_result));
}
else
{
// Return the unsupported exec. This should print a runtime warning. (amdgcn_mma.hpp)
// Code should not reach here, but HOST/DEVICE compile passes are
// weirdly intertwined and instead of having constexpr in the calling
// site (tests) we do this. See also changes by this commit.
return Derived::MmaOp::exec({}, {}, {});
}
}
template <typename VecTA,
typename VecTB,
typename VecTC,
typename ScaleADataType,
typename ScaleBDataType>
CK_TILE_DEVICE static decltype(auto)
exec(VecTA&& a, VecTB&& b, VecTC&& accum, ScaleADataType&& scale_A, ScaleBDataType&& scale_B)
{
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
{
// TODO: c++20: Call template functions with MmaPipelineOptionFlags directly
auto transformed_inputs = applyTransformsToInputs(
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTB>(b)
: std::forward<VecTA>(a),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTA>(a)
: std::forward<VecTB>(b),
std::forward<VecTC>(accum),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<ScaleBDataType>(scale_B)
: std::forward<ScaleADataType>(scale_A),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<ScaleADataType>(scale_A)
: std::forward<ScaleBDataType>(scale_B));
Derived::execImpl(transformed_inputs);
auto&& [a_result, b_result, c_result, scale_A_result, scale_B_result] =
std::move(transformed_inputs);
return applyTransformToOutput(std::move(c_result));
}
else
{
// Return the unsupported exec. This should print a runtime warning. (amdgcn_mma.hpp)
// Code should not reach here, but HOST/DEVICE compile passes are
// weirdly intertwined and instead of having constexpr in the calling
// site (tests) we do this. See also changes by this commit.
return Derived::MmaOp::exec({}, {}, {});
}
}
};
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>
/**
* @concept MmaPipelineI
* @brief Expresses the meta-data interface required for a CRTP MmaPipeline.
*/
template <typename Derived, MmaPipelineOptionFlags::Type Flags>
concept MmaPipelineInterface = std::derived_from<Derived, MmaPipelineBase<Flags, Derived>>;
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
} // namespace ck_tile::core::arch::mma

View File

@@ -72,3 +72,4 @@ concept MmaSelectorI = requires(MmaSelector op) {
// Include the implementations
#include "wmma/wmma_selector.hpp"
#include "mfma/mfma_selector.hpp"
#include "sparse/sparse_selector.hpp"

View File

@@ -18,6 +18,18 @@ struct PassThroughTransform
}
};
/**
* @struct MmaDefaultPassThroughTransforms
* @brief Implements the default MMA transforms
*/
struct MmaDefaultPassThroughTransforms
{
using ATransform = PassThroughTransform;
using BTransform = PassThroughTransform;
using CTransform = PassThroughTransform;
using DTransform = PassThroughTransform;
};
/**
* @class MmaTransformsDefaultSelector
* @brief Default selector for MmaTransforms based on MmaOp and CompilerTarget
@@ -27,7 +39,10 @@ struct PassThroughTransform
*/
template <typename MmaOp, typename CompilerTarget, typename Enable = void>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget, typename Enable = void>
struct MmaTransformsDefaultSelector;
struct MmaTransformsDefaultSelector
{
using SelectedTransforms = MmaDefaultPassThroughTransforms;
};
#if CK_TILE_CONCEPTS

View File

@@ -0,0 +1,177 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "amdgcn_mma.hpp"
#include "mma_pipeline.hpp"
#include "mma_selector.hpp"
#include "mma_transforms.hpp"
#include "mfma/mfma.hpp"
#include "wmma/wmma.hpp"
#include <tuple>
namespace ck_tile::core::arch::mma {
/*! @enum MmaAccumPolicy
* @brief Accumulation order for Mma decomposition
*/
enum struct MmaAccumPolicy
{
// Decomposition and accumulation in row-major fragment order
ROW_MAJOR,
// Decomposition and accumulation in col-major fragment order
COL_MAJOR
};
namespace dense::wavewise::detail {
// TODO: c++20: return MmaPipelineOptionFlags directly
template <bool SwapAB>
constexpr inline int getPipelineFlags()
{
return static_cast<int>(SwapAB ? MmaPipelineOptionFlag::ABSwap : MmaPipelineOptionFlag::NONE);
}
} // namespace dense::wavewise::detail
/**
* @class Mma
* @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation
* (e.g., mfma or wmma), this class performs fragment-wise (MmaTile) decomposition to
* matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and
* accumulates results into output WaveTile (C: WaveTileM x WaveTileN).
* @tparam ADataType Data type of input WaveTile A
* @tparam BDataType Data type of input WaveTile B
* @tparam CDataType Data type of input/output WaveTile C (accumulator)
* @tparam WaveTileM Mma WaveTile M dimension
* @tparam WaveTileN Mma WaveTile K dimension
* @tparam WaveTileK Mma WaveTile M dimension
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
* @tparam SwapAB Swaps A and B input vectors
* @tparam CompilerTarget The compiler target
* @tparam MmaOp_ Backend wrapper class that will perform the mma op (e.g., mfma or wmma)
* @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles
* @par This is an example of an Mma decomposition driver class that can be used in a wave-tile
* context. Given a WaveTile size, we can decompose the WaveTile into smaller mma op fragments
* that are natively supported by the hardware (e.g., mfma or wmma). The class also supports
* applying transforms to the input/output frags as needed (e.g., layout conversions, data type
* conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the
* output WaveTile. This is a powerful example of how to build a flexible and reusable mma driver
* that can adapt to different hardware capabilities and requirements.
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaOpFamily OpFamily,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
bool SwapAB = false,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
typename MmaOp_ =
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
OpFamily>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct WaveWiseMmaPipeline : public MmaPipelineBase<dense::wavewise::detail::getPipelineFlags<SwapAB>(),
WaveWiseMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, OpFamily, AccumPolicy, SwapAB, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<dense::wavewise::detail::getPipelineFlags<SwapAB>(),
WaveWiseMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, OpFamily, AccumPolicy, SwapAB, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
using MmaOp = MmaOp_;
// Fragment dimensions
constexpr static uint32_t FragM = MmaOp::kM;
constexpr static uint32_t FragN = MmaOp::kN;
constexpr static uint32_t FragK = MmaOp::kK;
// Fragment counts for decomposition
constexpr static uint32_t FragsM = WaveTileM / FragM;
constexpr static uint32_t FragsN = WaveTileN / FragN;
constexpr static uint32_t FragsK = WaveTileK / FragK;
constexpr static uint32_t FragsC = FragsM * FragsN;
// Vector types for packed registers in each fragment
using InternalAVecT = typename MmaOp::AVecType;
using InternalBVecT = typename MmaOp::BVecType;
using InternalCVecT = typename MmaOp::CVecType;
// Buffer types for WaveTiles
using AVecType = InternalAVecT[FragsM][FragsK];
using BVecType = InternalBVecT[FragsN][FragsK];
using CVecType = InternalCVecT[FragsM][FragsN];
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
// Sanity checks
static_assert(WaveTileM >= FragM, "WaveTileM must be larger than FragM");
static_assert(WaveTileN >= FragN, "WaveTileN must be larger than FragN");
static_assert(WaveTileK >= FragK, "WaveTileK must be larger than FragK");
static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM");
static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN");
static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK");
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static void execImpl(std::tuple<VecTA, VecTB, VecTC>& vecs)
{
auto& [a_frag, b_frag, c_frag] = vecs;
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
{
// "Row-major" accumulation over the N-dimension fragments first.
// Pseudo code here, but we would basically iterate over the fragments in row-major
// order. We also have to ensure that the incoming vector WaveTiles are converted to
// native vector types before passing to the FragWiseMma exec function.
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] =
MmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
}
}
}
}
else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
{
// "Col-major" accumulation over the M-dimension fragments first.
// Pseudo code here, but we would basically iterate over the blocks in col-major order
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] =
MmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
}
}
}
}
else
{
static_assert(false);
}
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,229 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
namespace ck_tile::core::arch::mma {
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
*
* This specialization implements the Scale MFMA instruction for fp8_t A and B
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 2, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
// clang-format on
{
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
static_cast<int>(CtrlFlags::OPSEL_A),
scale_A,
static_cast<int>(CtrlFlags::OPSEL_B),
scale_B)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
*
* This specialization implements the Scale MFMA instruction for bf8_t A and B
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 2, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
// clang-format on
{
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
static_cast<int>(CtrlFlags::OPSEL_A),
scale_A,
static_cast<int>(CtrlFlags::OPSEL_B),
scale_B)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
*
* This specialization implements the Scale MFMA instruction for pk_fp4_t A and B
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
// clang-format on
{
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
static_cast<int>(CtrlFlags::OPSEL_A),
scale_A,
static_cast<int>(CtrlFlags::OPSEL_B),
scale_B)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
*
* This specialization implements the Scale MFMA instruction for fp8_t A and B
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 2, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
// clang-format on
{
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
static_cast<int>(CtrlFlags::OPSEL_A),
scale_A,
static_cast<int>(CtrlFlags::OPSEL_B),
scale_B)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
*
* This specialization implements the Scale MFMA instruction for bf8_t A and B
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 2, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
// clang-format on
{
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
static_cast<int>(CtrlFlags::OPSEL_A),
scale_A,
static_cast<int>(CtrlFlags::OPSEL_B),
scale_B)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
*
* This specialization implements the Scale MFMA instruction for pk_fp4_t A and B
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
// clang-format on
{
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
static_cast<int>(CtrlFlags::OPSEL_A),
scale_A,
static_cast<int>(CtrlFlags::OPSEL_B),
scale_B)};
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,149 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp"
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <cstdint>
#include <type_traits>
namespace ck_tile::core::arch::mma {
/**
* @class ScaleMfmaDefaultSelector
* @brief Implements a default scale MFMA selector strategy. The SelectedOp can be unsupported.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM Size of the M dimension
* @tparam WaveTileN Size of the N dimension
* @tparam WaveTileKTest Size of the K dimension
* @tparam CompilerTarget The compiler target
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileKTest,
typename CompilerTarget>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) &&
// is_power_of_two_integer(WaveTileKTest))
struct ScaleMfmaDefaultSelector
{
private:
// Define our candidate MFMA implementation for the current parameters
using CandidateOp = amdgcn_mma<ADataType,
BDataType,
CDataType,
WaveTileM,
WaveTileN,
WaveTileKTest,
DefaultScaleMfmaCtrlFlags,
CompilerTarget,
MmaOpFamily::SCALE>;
public:
// If the candidate is supported (e.g., a backend implementation exists), then select it.
// Otherwise, fall back to the unsupported pass-through implementation.
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
CandidateOp,
amdgcn_mma<ADataType,
BDataType,
CDataType,
WaveTileM,
WaveTileN,
WaveTileKTest,
void,
amdgcn_target<>,
MmaOpFamily::UNDEFINED>>;
};
/**
* @struct MmaDefaultSelector
* @brief Implements the CDNA default MMA selector strategy for scale MFMA.
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
* @tparam CompilerTarget The compiler target
* @tparam OpFamily The MMA operation family
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK,
typename CompilerTarget,
MmaOpFamily OpFamily>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires
struct MmaDefaultSelector<ADataType,
BDataType,
CDataType,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
OpFamily,
enable_if_all<std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID,
amdgcn_target_id::GFX950)>,
std::enable_if_t<OpFamily == MmaOpFamily::SCALE>>>
{
private:
// Provide the default depth-K search strategy for each class of common MFMA shapes.
// Start searching from the largest K dimension MFMA shape down to the smallest.
using CandidateOp16x16 = typename ScaleMfmaDefaultSelector<ADataType,
BDataType,
CDataType,
16u,
16u,
128u,
CompilerTarget>::SelectedOp;
using CandidateOp32x32 = typename ScaleMfmaDefaultSelector<ADataType,
BDataType,
CDataType,
32u,
32u,
64u,
CompilerTarget>::SelectedOp;
// Default operation triggers pass-through
using DefaultOp = typename ScaleMfmaDefaultSelector<ADataType,
BDataType,
CDataType,
1u,
1u,
1u,
CompilerTarget>::SelectedOp;
// Check if each candidate is supported for the given fragment sizes
// For this case, we require the fragment sizes to be multiples of the MFMA shape
static constexpr bool IsSupported16x16 =
MmaOpTraits<CandidateOp16x16>::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) &&
(WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u);
static constexpr bool IsSupported32x32 =
MmaOpTraits<CandidateOp32x32>::IsSupported && (WaveTileM % CandidateOp32x32::kM == 0u) &&
(WaveTileN % CandidateOp32x32::kN == 0u) && (WaveTileK % CandidateOp32x32::kK == 0u);
public:
// Select the largest supported MFMA operation for the given fragment shape
using SelectedOp =
std::conditional_t<IsSupported32x32,
CandidateOp32x32,
std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>>;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,10 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
// Include scale MFMA traits and architecture-specific implementations
#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp"
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"

View File

@@ -0,0 +1,77 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"
#include "ck_tile/core/config.hpp"
#include <cstdint>
#include <tuple>
#include <type_traits>
#include <utility>
namespace ck_tile::core::arch::mma {
template <typename ADataType,
typename BDataType,
typename CDataType,
std::uint32_t FragM,
std::uint32_t FragN,
std::uint32_t FragK,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
typename MmaOp_ =
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp_ = typename
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
CompilerTarget,
MmaOpFamily::SCALE>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
using MmaOp = MmaOp_; // Expose the selected MmaOp
// Expose caller-side vector types
using AVecType = typename MmaOp::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
// Expose internal vector types
using InternalAVecT = typename MmaOp::AVecType;
using InternalBVecT = typename MmaOp::BVecType;
using InternalCVecT = typename MmaOp::CVecType;
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
template <typename VecTA,
typename VecTB,
typename VecTC,
typename ScaleADataType,
typename ScaleBDataType>
CK_TILE_DEVICE static void
execImpl(std::tuple<VecTA, VecTB, VecTC, ScaleADataType, ScaleBDataType>& vecs)
{
auto& [a_vec, b_vec, c_vec, scale_A, scale_B] = vecs;
c_vec = MmaOp::exec(a_vec, b_vec, c_vec, scale_A, scale_B);
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,6 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/mma/scale/mfma/selector.hpp"

View File

@@ -0,0 +1,93 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
// #include "ck_tile/core/numeric/pk_fp6.hpp"
#include <cstdint>
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
namespace ck_tile::core::arch::mma {
namespace scale::detail {
template <typename T>
struct ScaleDataTypeToFlag;
template <>
struct ScaleDataTypeToFlag<fp8_t> // e4m3
{
static constexpr std::int32_t value = 0;
};
template <>
struct ScaleDataTypeToFlag<bf8_t> // e5m2
{
static constexpr std::int32_t value = 1;
};
// template <>
// struct ScaleDataTypeToFlag<pk_fp6_t<1>> // e2m3
// {
// static constexpr std::int32_t value = 2;
// };
// template <>
// struct ScaleDataTypeToFlag<bf6_t> // e3m2
// {
// static constexpr std::int32_t value = 3;
// };
template <>
struct ScaleDataTypeToFlag<pk_fp4_t> // e2m1
{
static constexpr std::int32_t value = 4;
};
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
/**
* @concept ScaleMfmaDataTypeToFlag
* @brief Expresses the interface of required members for each DataTypeToFlag type on Gfx9
*/
template <typename DataTypeToFlag>
concept ScaleMfmaDataTypeToFlag = requires(DataTypeToFlag dataTypeToFlag) {
// Flag members for scale MFMA instructions
{ DataTypeToFlag::value } -> std::convertible_to<int>;
};
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
template <typename T>
inline constexpr std::int32_t ScaleDataTypeToFlag_v = ScaleDataTypeToFlag<T>::value;
} // namespace scale::detail
struct DefaultScaleMfmaCtrlFlags
{
static constexpr std::int32_t OPSEL_A = 0;
static constexpr std::int32_t OPSEL_B = 0;
};
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
/**
* @concept ScaleMfmaCtrlFlags
* @brief Expresses the interface of required members for each CtrlFlags type on Gfx9
*/
template <typename CtrlFlags>
concept ScaleMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) {
// Flag members for scale MFMA instructions
{ CtrlFlags::OPSEL_A } -> std::convertible_to<int>;
{ CtrlFlags::OPSEL_B } -> std::convertible_to<int>;
};
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,43 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
#include <type_traits>
namespace ck_tile::core::arch::mma {
/**
* @struct MmaDefaultTransformsScale
* @brief Implements the default MMA transforms for Scale
*/
struct MmaDefaultTransformsScale
{
using ATransform = PassThroughTransform;
using BTransform = PassThroughTransform;
using CTransform = PassThroughTransform;
using DTransform = PassThroughTransform;
};
/**
* @struct MmaTransformsDefaultSelector
* @brief Specialization for Scale MFMA transforms
* Provides default transform selection for scale operations
*
* @tparam MmaOp Scale MMA operation
* @tparam CompilerTarget The compiler target
*/
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires(is_mma_op_scale(MmaOp))
template <typename MmaOp, typename CompilerTarget>
struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::SCALE>>
{
using SelectedTransforms = MmaDefaultTransformsScale;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -6,7 +6,6 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
namespace ck_tile::core::arch::mma {
@@ -31,25 +30,12 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
// clang-format on
{
CK_TILE_DEVICE static auto
exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
{
static constexpr index_t ABVecN = vector_traits<AVecType>::vector_size;
static constexpr index_t kCompressionRatio = 2;
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
static_assert(CompressedSize == 4);
// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
// and evaluate changing this to a transform at a higher level.
// aVec not being const can cause problems when running multiple intrinsics.
const uint32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
const AVecCompressed a_vec_pruned = {aVec[0], aVec[1], aVec[2], aVec[3]};
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_16x16x32_f16(
a_vec_pruned, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
}
};

View File

@@ -0,0 +1,100 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include <cstdint>
#include <type_traits>
namespace ck_tile::core::arch::mma {
namespace sparse::detail {
// TODO: c++20: return MmaPipelineOptionFlags directly
constexpr inline int getPipelineFlags()
{
return static_cast<int>(MmaPipelineOptionFlag::COMPRESS_A);
}
} // namespace sparse::detail
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
typename MmaOp_ =
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
CompilerTarget,
MmaOpFamily::SPARSE>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
static_assert(!Base::template hasFlag<MmaPipelineOptionFlag::ABSwap>(),
"Cannot transpose C in sparse intrinsics.");
using MmaOp = MmaOp_; // Expose the selected MmaOp
// Calculate the uncompressed A vector type
struct ExternalAVecCalculator
{
using AVecTraits = vector_traits<typename MmaOp::AVecType>;
static constexpr index_t ASize = AVecTraits::vector_size * MmaOp::kCompressionRatio;
using AVecType = ext_vector_t<typename AVecTraits::scalar_type, ASize>;
};
// Expose caller-side vector types
using AVecType = typename ExternalAVecCalculator::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
// Expose internal vector types
using InternalAVecT = typename MmaOp::AVecType;
using InternalBVecT = typename MmaOp::BVecType;
using InternalCVecT = typename MmaOp::CVecType;
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
template <typename ATransformResult, typename BTransformResult, typename CTransformResult>
CK_TILE_DEVICE static void
execImpl(std::tuple<ATransformResult, BTransformResult, CTransformResult>& vecs)
{
checkATransformResult<ATransformResult>();
auto& [a_result, b_vec, c_vec] = vecs;
auto& [a_vec, idx] = a_result;
c_vec = MmaOp::exec(a_vec, b_vec, c_vec, idx);
}
private:
// Type check helper - not a device function, so std::declval is available
template <typename ATransformResult>
static constexpr void checkATransformResult()
{
using ExternalAvecRef = std::add_lvalue_reference_t<AVecType>;
static_assert(std::is_same_v<ATransformResult,
decltype(ATransform::exec(std::declval<ExternalAvecRef>()))>);
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -43,18 +43,15 @@ struct BuiltinParams
template <SparseCompressionIndex Idx>
static constexpr BuiltinParams getBuiltinParams()
{
BuiltinParams params;
// TODO c++20: designated initializers
if constexpr(Idx == SparseCompressionIndex::FIRST)
{
params.UseFirstIndex = 1;
params.ByteIndexToOverride = 0;
return BuiltinParams{1, 0};
}
else
{
params.UseFirstIndex = 0;
params.ByteIndexToOverride = static_cast<int>(Idx);
return BuiltinParams{0, static_cast<int>(Idx)};
}
return params;
}
} // namespace sparse::detail

View File

@@ -6,22 +6,101 @@
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include <cstdint>
namespace ck_tile::core::arch::mma {
namespace sparse::detail {
/**
* @struct MmaDefaultTransformsSparse
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
* elements into lower part of a_vec to half its effective size.
* @param a_vec Vector to be compressed.
* @tparam ADataType The data type of a_vec
* @tparam CompressedSize The target compression size
* @tparam AVec The vector type of a_vec (deduced)
* @return Packed 32bit word containing **CompressedSize** 2bit fields.
* Each field encodes the original position (03) of the corresponding
* nonzero element in the input. If fewer than CompressedSize
* nonzeros are found, remaining fields default to 2 (see below).
*/
template <typename ADataType, index_t CompressedSize, typename AVec>
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
{
// idx holds one 2bit index per output element (total CompressedSize entries).
// It is initialized to the pattern 0b10 for every field. This matches
// what the hardware expects when there are fewer than two nonzero values
// in a 4element group the unused output is treated as coming from slot 2.
// The loop below will clear and set each field as real nonzeros are seen.
int32_t idx = 0;
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2u << (2u * k)); });
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
int32_t non_zero_pos = 0;
static_for<0, 4, 1>{}([&](auto j) {
if(static_cast<float>(a_vec[i * 4 + j]) != 0.0f)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
// clear the twobit field for this output and insert j
idx &= ~(0b11u << (2u * (i * 2 + non_zero_pos)));
idx |= static_cast<uint32_t>(j) << (2u * (i * 2 + non_zero_pos));
++non_zero_pos;
}
});
a_vec[i * 2] = nonzero_elems[0];
a_vec[i * 2 + 1] = nonzero_elems[1];
});
return idx;
}
} // namespace sparse::detail
/**
* @class SparseCompressTransform
* @brief Performs 2:4 structured sparsity compression to the vector v and produces an index mask.
* @note Returns a tuple of two. The first element is the vector v with the same scalar type but
* its size halved. The second element is the index mask.
*/
template <index_t CompressionRatio>
struct SparseCompressTransform
{
template <typename VecType>
CK_TILE_DEVICE static decltype(auto) exec(VecType& v)
{
using VecTraits = vector_traits<remove_cvref_t<VecType>>;
using ScalarT = typename VecTraits::scalar_type;
static constexpr auto VecN = VecTraits::vector_size;
static constexpr index_t CompressedSize = VecN / CompressionRatio;
using VecCompressed = ext_vector_t<ScalarT, CompressedSize>;
static_assert(VecN % CompressionRatio == 0, "VecN must be divisible by CompressionRatio");
static_assert(CompressedSize > 0, "CompressedSize must be > 0");
const auto idx = sparse::detail::compress_a_impl<ScalarT, CompressedSize>(v);
// TODO c++20: Use bit_cast
return std::tuple<VecCompressed&, int32_t>(
*std::launder(reinterpret_cast<VecCompressed*>(&v)), idx);
}
};
/**
* @class MmaDefaultTransformsSparse
* @brief Implements the default transforms for Sparse
*
* For 2:4 structured sparsity with inline register metadata:
* - ATransform: Pass-through (sparse operands formatted in Exec) TODO!
* - ATransform: 2:4 structured sparsity compression
* - BTransform: Pass-through (sparse operands already formatted)
* - CTransform: Pass-through (input accumulator)
* - DTransform: Pass-through (output accumulator as-is)
*/
template <index_t CompressionRatio>
struct MmaDefaultTransformsSparse
{
using ATransform = PassThroughTransform;
using ATransform = SparseCompressTransform<CompressionRatio>;
using BTransform = PassThroughTransform;
using CTransform = PassThroughTransform;
using DTransform = PassThroughTransform;
@@ -42,7 +121,7 @@ struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::SPARSE>>
{
using SelectedTransforms = MmaDefaultTransformsSparse;
using SelectedTransforms = MmaDefaultTransformsSparse<MmaOp::kCompressionRatio>;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -7,7 +7,6 @@
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
namespace ck_tile::core::arch::mma {
@@ -21,23 +20,9 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
// clang-format on
{
CK_TILE_DEVICE static auto
exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
{
static constexpr index_t ABVecN = vector_traits<AVecType>::vector_size;
static constexpr index_t kCompressionRatio = 2;
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
static_assert(CompressedSize == 8);
// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
// and evaluate changing this to a transform at a higher level.
// aVec not being const can cause problems when running multiple intrinsics.
const uint32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
const AVecCompressed a_vec_pruned = {
aVec[0], aVec[1], aVec[2], aVec[3], aVec[4], aVec[5], aVec[6], aVec[7]};
return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(a_vec_pruned, bVec, cVec, idx)};
return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(aVec, bVec, cVec, idx)};
}
};

View File

@@ -0,0 +1,114 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <algorithm>
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
namespace ck_tile::core::arch::mma {
/**
* @class TileDistrEncCalc
* @brief Given an MmaOp and modifiers, provides warp-level tile distribution encodings for mapping
* ABC matrix fragment coordinates to register coordinates (lane, vector item) and vice versa.
* @tparam MmaOp Intrinsic (amdgcn_mma).
* @tparam CTranspose Whether we are using CTranspose.
* @tparam SFactor Swizzle factor. Not implemented.
* @tparam AttrNumAccessA Requested NumAccess for the A matrix. Must be multiple of "fundamental"
* NumAccess for intrinsic. See details in amdgcn_mma.hpp.
* @tparam AttrNumAccessB Requested NumAccess for the B matrix.
*/
template <typename MmaOp,
bool CTranspose = false,
index_t SFactor = 1,
index_t AttrNumAccessA = MmaOp::kAKNumAccess,
index_t AttrNumAccessB = MmaOp::kBKNumAccess>
struct TileDistrEncCalc
{
private:
static constexpr index_t NumAccessA = std::max(MmaOp::kAKNumAccess, AttrNumAccessA);
static constexpr index_t NumAccessB = std::max(MmaOp::kBKNumAccess, AttrNumAccessB);
// We are free to choose any NumAccess value to manipulate the load / store behavior, unless the
// intrinsic fundamentally requires a base NumAccess factor for the layout to be correct.
static_assert(AttrNumAccessA % MmaOp::kAKNumAccess == 0,
"Requesting NumAccessA incompatible with builtin.");
static_assert(AttrNumAccessB % MmaOp::kBKNumAccess == 0,
"Requesting NumAccessB incompatible with builtin.");
static_assert(MmaOp::kABKPerLane % NumAccessA == 0);
static_assert(MmaOp::kABKPerLane % NumAccessB == 0);
static_assert(SFactor == 1, "Swizzle not implemented yet."); // TODO: Implement Swizzle.
template <index_t MajorDimSize, index_t Repeat, index_t NumAccess, index_t CompressionRatio = 1>
using ABWarpDstrEnc = tile_distribution_encoding<
sequence<Repeat>,
tuple<sequence<MajorDimSize>,
sequence<NumAccess,
MmaOp::kK / MmaOp::kABKPerLane,
MmaOp::kABKPerLane / NumAccess / CompressionRatio>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 1, 0>>,
sequence<2, 2>,
sequence<0, 2>>;
static constexpr auto get_cwarp_dstr_encoding()
{
// We unmerge the M and N dimensions in the same way every time.
using MSubDims = sequence<MmaOp::kCMBlocks,
MmaOp::kCMNumAccess,
MmaOp::kM / MmaOp::kCMBlocks / MmaOp::kCMPerLane,
MmaOp::kCMPerLane / MmaOp::kCMNumAccess>;
using NSubDims = sequence<MmaOp::kCNBlocks, MmaOp::kN / MmaOp::kCNBlocks>;
// In case of CTranspose, all we do is swap the M and N dimension.
using MatDims =
std::conditional_t<CTranspose, tuple<NSubDims, MSubDims>, tuple<MSubDims, NSubDims>>;
constexpr int MInx = CTranspose ? 2 : 1;
constexpr int NInx = CTranspose ? 1 : 2;
// For MFMA intrinsics with blocks, the block dimensions might be in the Lane dim or in the
// Vec dim, so we get different merge orderings.
if constexpr(MmaOp::CBlockDimInVecDim)
{
return tile_distribution_encoding<sequence<1>,
MatDims,
tuple<sequence<MInx, NInx>>,
tuple<sequence<2, 1>>,
sequence<MInx, NInx, MInx, MInx>,
sequence<0, 0, 1, 3>>{};
}
else
{
return tile_distribution_encoding<sequence<1>,
MatDims,
tuple<sequence<MInx, NInx, MInx, NInx>>,
tuple<sequence<0, 0, 2, 1>>,
sequence<MInx, MInx>,
sequence<1, 3>>{};
}
}
using AEnc_ = ABWarpDstrEnc<MmaOp::kM, MmaOp::kARepeat, NumAccessA, MmaOp::kCompressionRatio>;
using BEnc_ = ABWarpDstrEnc<MmaOp::kN, MmaOp::kBRepeat, NumAccessB>;
public:
// When using CTranspose, the A and B matrices are swapped.
using AWarpDstrEncoding = std::conditional_t<CTranspose, BEnc_, AEnc_>;
using BWarpDstrEncoding = std::conditional_t<CTranspose, AEnc_, BEnc_>;
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
// Some additional consistency checks
static_assert(TileDistrEncRegMap<AWarpDstrEncoding>::num_lanes == MmaOp::WaveSize);
static_assert(TileDistrEncRegMap<BWarpDstrEncoding>::num_lanes == MmaOp::WaveSize);
static_assert(TileDistrEncRegMap<CWarpDstrEncoding>::num_lanes == MmaOp::WaveSize);
static_assert(TileDistrEncRegMap<AWarpDstrEncoding>::num_vector_items ==
vector_traits<typename MmaOp::AVecType>::vector_size);
static_assert(TileDistrEncRegMap<BWarpDstrEncoding>::num_vector_items ==
vector_traits<typename MmaOp::BVecType>::vector_size);
static_assert(TileDistrEncRegMap<CWarpDstrEncoding>::num_vector_items ==
vector_traits<typename MmaOp::CVecType>::vector_size);
};
} // namespace ck_tile::core::arch::mma

View File

@@ -86,9 +86,11 @@ struct MmaDefaultTransformsGfx12
template <typename MmaOp, typename CompilerTarget>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
// TODO: c++20 requires
struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
enable_if_target_family_gfx11_t<CompilerTarget>>
struct MmaTransformsDefaultSelector<
MmaOp,
CompilerTarget,
enable_if_all<enable_if_target_family_gfx11_t<CompilerTarget>,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::DENSE>>>
{
using SelectedTransforms = MmaDefaultTransformsGfx11;
};
@@ -102,9 +104,11 @@ struct MmaTransformsDefaultSelector<MmaOp,
template <typename MmaOp, typename CompilerTarget>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
// TODO: c++20 requires
struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
enable_if_target_family_gfx12_t<CompilerTarget>>
struct MmaTransformsDefaultSelector<
MmaOp,
CompilerTarget,
enable_if_all<enable_if_target_family_gfx12_t<CompilerTarget>,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::DENSE>>>
{
using SelectedTransforms = MmaDefaultTransformsGfx12;
};

View File

@@ -74,6 +74,7 @@
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_CNAN 5
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_STANDARD
@@ -209,6 +210,17 @@
#endif
#endif
// workaround for AMDGPU compiler VGPR aliasing bug in dropout codegen (ROCm >= 7.12)
// Philox RNG VGPR parameters get aliased under high register pressure (d256 tile).
// fp16 is affected; bf16 is not (different type conversion codegen path).
#ifndef CK_TILE_WORKAROUND_ROCM_7_12_FP16_DROPOUT_MISCOMPILE
#if(HIP_VERSION_MAJOR == 7 && HIP_VERSION_MINOR >= 12) || (HIP_VERSION_MAJOR > 7)
#define CK_TILE_WORKAROUND_ROCM_7_12_FP16_DROPOUT_MISCOMPILE 1
#else
#define CK_TILE_WORKAROUND_ROCM_7_12_FP16_DROPOUT_MISCOMPILE 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif

View File

@@ -84,19 +84,6 @@ struct array
data[i] = static_cast<value_type>(c);
}
// template <typename Y>
// CK_TILE_HOST_DEVICE constexpr array(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// }
// CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// return *this;
// }
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
@@ -247,13 +234,6 @@ CK_TILE_HOST_DEVICE constexpr details::return_type<D, Ts...> make_array(Ts&&...
return {std::forward<Ts>(ts)...};
}
// // make empty array
// template <typename T>
// CK_TILE_HOST_DEVICE constexpr auto make_array()
// {
// return array<T, 0>{};
// }
// compatible with old ck's initializer, make an array and fill it withe the last element from
// initializer_list
template <typename T, index_t Size>

View File

@@ -480,32 +480,6 @@ struct sequence_split
using right_type = decltype(Seq::extract(range1{}));
};
#if 0
// reverse sequence
template <typename Seq>
struct sequence_reverse
{
static constexpr index_t NSize = Seq{}.size();
using seq_split = sequence_split<Seq, NSize / 2>;
using type = typename sequence_merge<
typename sequence_reverse<typename seq_split::right_type>::type,
typename sequence_reverse<typename seq_split::left_type>::type>::type;
};
template <index_t I>
struct sequence_reverse<sequence<I>>
{
using type = sequence<I>;
};
template <index_t I0, index_t I1>
struct sequence_reverse<sequence<I0, I1>>
{
using type = sequence<I1, I0>;
};
#endif
namespace detail {
template <typename Id, index_t... Ns>
struct seq_reverse;

View File

@@ -24,18 +24,4 @@ using statically_indexed_array = array<T, N>;
#endif
// consider always use ck_tile::array for this purpose
#if 0
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
return statically_indexed_array<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
}
// make empty statically_indexed_array
template <typename X>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array()
{
return statically_indexed_array<X, 0>();
}
#endif
} // namespace ck_tile

View File

@@ -23,18 +23,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
}
#else
#if 0
template <typename T, index_t N>
using thread_buffer = array<T, N>;
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
return make_array(ts...);
}
#endif
// clang-format off
template<typename T_, index_t N_>
struct thread_buffer {
@@ -103,25 +91,6 @@ struct thread_buffer {
return vx.data;
}
#if 0
template <typename X_,
index_t Is,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void _set_as(number<Is> is, X_ x)
{
using X = remove_cvref_t<X_>;
constexpr index_t kSPerX = vector_traits<X>::vector_size;
union {
X_ data;
tuple_array<value_type, kSPerX> sub_data;
} vx {x};
static_for<0, kSPerX, 1>{}(
[&](auto j) { operator()((is * number<sizeof(X_)/sizeof(value_type)>{}) + j) = vx.sub_data[j]; });
}
#endif
#define TB_COMMON_AS() \

View File

@@ -292,9 +292,6 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
// below function should be used under tuple_array<> type, no extra check will perform here
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast<tuple_array<Tx, size()>&>(*this); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { return reinterpret_cast<const tuple_array<Tx, size()>&>(*this); }
// below index is for index *AFTER* type convert, not before
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i); }
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(i); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
@@ -333,13 +330,6 @@ struct vector_traits<tuple<T...>, void>
static constexpr index_t vector_size = sizeof...(T);
};
// template <class... T>
// CK_TILE_HOST_DEVICE constexpr
// tuple<T...>
// make_tuple(T const&... t)
// {
// return {t...};
// }
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple<Xs...>& a, const tuple<Xs...>& b)
{

View File

@@ -22,7 +22,8 @@ enum class bf16_rounding_mode
truncate_with_nan,
truncate,
standard_asm,
rta_asm, // round to nearest away
rta_asm, // round to nearest away
standard_cnan, // rtn with canonical NaN
};
template <bf16_rounding_mode rounding =
@@ -226,6 +227,39 @@ uint16_t float_to_bf16_rta_asm(float f)
return u.hi;
}
CK_TILE_HOST_DEVICE
constexpr bool float_is_nan_raw(float f)
{
#if defined(__has_builtin) && __has_builtin(__builtin_isnan)
return __builtin_isnan(f);
#else
uint32_t bits = bit_cast<uint32_t>(f);
constexpr uint32_t exp_mask = 0x7f800000;
constexpr uint32_t mant_mask = 0x007fffff;
return (bits & exp_mask) == exp_mask && (bits & mant_mask);
#endif
}
// Round to nearest even, but canonicalize any NaN input to the canonical quiet bf16 NaN
// (`0x7fff`). Unlike `float_to_bf16_rtn_raw`, this does not preserve signaling NaN
// payload/state.
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_rtn_cnan_raw(float f)
{
#if defined(__FAST_MATH__) || (defined(__FINITE_MATH_ONLY__) && __FINITE_MATH_ONLY__)
// Fast/finite-math can fold the NaN predicate away, so fall back to standard RTN.
return float_to_bf16_rtn_raw(f);
#else
// `-fgpu-flush-denormals-to-zero` only affects denormals, not NaN handling.
uint32_t bits = bit_cast<uint32_t>(f);
uint32_t tmp = (bits >> 16) & 1;
uint32_t res = float_is_nan_raw(f) ? 0x7fff0000 : bits + tmp + 0x7fff;
return uint16_t(res >> 16);
#endif
}
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
@@ -249,6 +283,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
return float_to_bf16_rtn_raw(f);
else if constexpr(rounding == bf16_rounding_mode::standard_asm)
return float_to_bf16_rtn_asm(f);
else if constexpr(rounding == bf16_rounding_mode::standard_cnan)
return float_to_bf16_rtn_cnan_raw(f);
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
return float_to_bf16_truc_nan_raw(f);
else if constexpr(rounding == bf16_rounding_mode::rta_asm)

View File

@@ -264,93 +264,6 @@ bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.t
CK_TILE_DEVICE
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
#if 0
CK_TILE_DEVICE
half_t operator+(const half_t& x, const half_t& y)
{
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
CK_TILE_DEVICE
half_t operator-(const half_t& x, const half_t& y)
{
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator*(const half_t& x, const half_t& y)
{
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator/(const half_t& x, const half_t& y)
{
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t& operator+=(half_t& x, const half_t& y)
{
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator-=(half_t& x, const half_t& y)
{
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator*=(half_t& x, const half_t& y)
{
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator/=(half_t& x, const half_t& y)
{
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator++(half_t& x)
{
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator--(half_t& x)
{
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t operator++(half_t& x, int)
{
half_t y(x);
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return y;
}
CK_TILE_DEVICE
half_t operator--(half_t& x, int)
{
half_t y(x);
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return y;
}
#endif
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
#endif

View File

@@ -73,27 +73,6 @@ struct numeric<int8_t>
CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return 0; }
};
#if 0
template <>
struct numeric_traits<int8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr int bias = 15;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
static constexpr int PackedSize = 1;
using bitwise_type = uint16_t;
};
#endif
CK_TILE_HOST_DEVICE
constexpr float int8_to_float(const int8_t& x) { return static_cast<float>(x); }

View File

@@ -295,10 +295,6 @@ struct tile_sweeper
F f;
};
// partial deduction is not allowed
// template <typename T, typename F, typename U>
// tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
// deduction guide
template <typename T,
typename F,

View File

@@ -236,12 +236,13 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
namespace detail {
template <typename Lengths, typename Strides, index_t I, typename AccOld>
CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
const Strides& strides,
number<I> i,
AccOld acc_old)
CK_TILE_HOST_DEVICE constexpr long_index_t calculate_element_space_size_impl(const Lengths& lengths,
const Strides& strides,
number<I> i,
AccOld acc_old)
{
auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i];
long_index_t acc_new = acc_old + static_cast<long_index_t>(lengths[i] - number<1>{}) *
static_cast<long_index_t>(strides[i]);
if constexpr(i.value < Lengths::size() - 1)
{
@@ -287,8 +288,12 @@ make_naive_tensor_descriptor(const tuple<Lengths...>& lengths,
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
const auto element_space_size =
const long_index_t element_space_size_long =
detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{});
constexpr long_index_t element_space_size_clamp_value =
static_cast<long_index_t>(std::numeric_limits<index_t>::max());
const index_t element_space_size =
static_cast<index_t>(std::min(element_space_size_long, element_space_size_clamp_value));
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
@@ -323,8 +328,12 @@ make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
const auto desc_0 = [&]() {
const auto element_space_size = detail::calculate_element_space_size_impl(
const auto element_space_size_long = detail::calculate_element_space_size_impl(
lengths, strides, number<0>{}, long_number<1>{});
constexpr long_index_t element_space_size_clamp_value =
static_cast<long_index_t>(std::numeric_limits<index_t>::max());
const index_t element_space_size =
static_cast<index_t>(std::min(element_space_size_long, element_space_size_clamp_value));
const auto transforms = make_tuple(make_offset_transform(element_space_size, os));

View File

@@ -454,45 +454,6 @@ struct tile_distribution_detail
} // namespace detail
#if 0
// this returns a constexpr tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
{
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
constexpr auto adaptor_impl =
detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
constexpr index_t d_length = adaptor_impl.template at<2>();
constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
constexpr auto ps_ys_to_xs_adaptor =
CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
constexpr auto ys_to_d_descriptor =
make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length);
//
constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
constexpr auto rh_major_minor_to_hidden_ids =
TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
return tile_distribution<
remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
remove_cvref_t<decltype(ys_to_d_descriptor)>,
remove_cvref_t<DstrEncode>,
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
#endif
// this returns a static tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)

View File

@@ -209,4 +209,21 @@ template <typename ADataType, typename BDataType>
using largest_type_t =
std::conditional_t<sizeof(ADataType) >= sizeof(BDataType), ADataType, BDataType>;
/**
* @brief Type trait to detect whether a type is a @c std::tuple specialization.
* @tparam T The type to inspect.
*/
template <typename T>
struct is_std_tuple : std::false_type
{
};
template <typename... Args>
struct is_std_tuple<std::tuple<Args...>> : std::true_type
{
};
template <typename T>
static constexpr bool is_std_tuple_v = is_std_tuple<T>::value;
} // namespace ck_tile

View File

@@ -25,6 +25,7 @@ template <> struct DataTypeTraits<pk_fp4_t> { static constexpr const char * name
template <> struct DataTypeTraits<pk_fp6x16_t> { static constexpr const char * name = "pk_fp6x16"; };
template <> struct DataTypeTraits<pk_fp4_raw_t> { static constexpr const char * name = "pk_fp4_raw"; };
template <> struct DataTypeTraits<e8m0_t> { static constexpr const char * name = "e8m0"; };
template <> struct DataTypeTraits<ck_tile::tf32_t>{ static constexpr const char* name = "tf32"; };
template <memory_operation_enum MemOp> struct memOpToStr;
template <> struct memOpToStr<memory_operation_enum::set> { static constexpr const char * name = "set"; };

View File

@@ -27,10 +27,13 @@ struct ElementWiseKernel
return is_wave32() ? kBlockSize / 2 : kBlockSize;
}
template <typename... XDataType, typename Dims>
CK_TILE_DEVICE void operator()(const Dims lens,
const Dims input_strides,
const Dims output_strides,
template <typename... XDataType,
typename DimsLens,
typename DimsInStrides,
typename DimsOutStrides>
CK_TILE_DEVICE void operator()(const DimsLens lens,
const DimsInStrides input_strides,
const DimsOutStrides output_strides,
const tuple<XDataType...>& input_tensors,
YDataType* p_y) const
{
@@ -49,10 +52,11 @@ struct ElementWiseKernel
input_tensors.get(i), lens, input_strides, number<S::kVectorM>{}, number<1>{});
const auto transformed_tensor = pad_tensor_view(
transform_tensor_view(tensor_view,
ck_tile::make_tuple(merge_transform),
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
ck_tile::make_tuple(sequence<0>{})),
transform_tensor_view(
tensor_view,
ck_tile::make_tuple(merge_transform),
ck_tile::make_tuple(make_index_sequence<DimsLens::size()>{}),
ck_tile::make_tuple(sequence<0>{})),
ck_tile::make_tuple(number<S::kBlockM>{}),
sequence<Problem::kPad>{});
@@ -86,13 +90,14 @@ struct ElementWiseKernel
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, lens, output_strides, number<S::kVectorM>{});
const auto transformed_y_m_n = pad_tensor_view(
transform_tensor_view(y_m_n,
ck_tile::make_tuple(merge_transform),
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
ck_tile::make_tuple(sequence<0>{})),
ck_tile::make_tuple(number<S::kBlockM>{}),
sequence<Problem::kPad>{});
const auto transformed_y_m_n =
pad_tensor_view(transform_tensor_view(
y_m_n,
ck_tile::make_tuple(merge_transform),
ck_tile::make_tuple(make_index_sequence<DimsOutStrides::size()>{}),
ck_tile::make_tuple(sequence<0>{})),
ck_tile::make_tuple(number<S::kBlockM>{}),
sequence<Problem::kPad>{});
auto y_window = make_tile_window(transformed_y_m_n,
make_tuple(number<S::kBlockM>{}),

View File

@@ -745,14 +745,6 @@ struct PassThroughPack2
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
#if 0
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::fp16x2_t& y, const ck_tile::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<fp16x2_t>(t);
}
#endif
CK_TILE_HOST_DEVICE constexpr void operator()(fp16x2_t& y, const pk_int4_t& x) const
{
uint8_t x_u8 = bit_cast<uint8_t>(x);
@@ -871,61 +863,6 @@ struct UnaryConvert
}
};
#if 0
struct ConvertBF16RTN
{
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::bf16_t>, "Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x);
}
};
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x);
}
};
struct ConvertF8RNE
{
// convert to fp8 using rounding to nearest even
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_rne<Y>(x);
}
};
#endif
struct Scale
{
static constexpr const char* name = "Scale";

View File

@@ -10,6 +10,7 @@
#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/epilogue/permuten_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"

View File

@@ -33,7 +33,6 @@ template <typename AsDataType_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false,
index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp
bool DoubleSmemBuffer_ = false>
struct CShuffleEpilogueProblem
@@ -59,7 +58,6 @@ struct CShuffleEpilogueProblem
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -658,152 +656,8 @@ struct CShuffleEpilogue
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM = EmptyScale,
typename ScaleN = EmptyScale,
int EnablePermuateN_ = TiledMMAPermuteN,
std::enable_if_t<EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* /* p_smem */,
const ScaleM& scale_m = {},
const ScaleN& scale_n = {})
{
static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size();
static_assert(MPerXdl % RowsPerLane == 0,
"CShuffle (permuteN): MPerXdl must be divisible by per-lane row count.");
constexpr int kM0 = MWave;
constexpr int kM2 = RowsPerLane;
constexpr int kM1 = MPerXdl / kM2;
constexpr int kN0 = NWave;
constexpr int kN1 = NPerXdl;
constexpr int kN2 = NRepeat;
using IntrThreadShuffleEncode =
tile_distribution_encoding<sequence<>,
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>;
constexpr auto dram_tile_distribution =
make_static_tile_distribution(IntrThreadShuffleEncode{});
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
// Optional scales (must share the same distribution to match per-thread indexing)
constexpr bool has_scales =
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
constexpr bool has_scalar_scales =
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
// Tiles to hold row/col scales when present
using SMType = typename ScaleDataType<ScaleM>::DataType;
using SNType = typename ScaleDataType<ScaleN>::DataType;
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
// Build windows only if non-scalar scales are provided
auto scale_m_window = [&]() {
if constexpr(has_scales && !has_scalar_scales)
{
return make_tile_window(scale_m, dram_tile_distribution);
}
else
{
return EmptyScale{};
}
}();
auto scale_n_window = [&]() {
if constexpr(has_scales && !has_scalar_scales)
{
return make_tile_window(scale_n, dram_tile_distribution);
}
else
{
return EmptyScale{};
}
}();
static_for<0, MRepeat, 1>{}([&](auto mIter) {
// Slice accumulators for this M repeat into the permuted layout
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
// If non-scalar scales provided, load them with identical distribution
if constexpr(has_scales && !has_scalar_scales)
{
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
}
// Pack 4 “rows per lane” as you already do
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
// source indices in shuffle_acc: (n_idx * product(Y) + row)
const index_t plane = c_warp_y_lengths.product();
// local lambda to fuse scale (if present) and convert
static_for<0, kM2, 1>{}([&](auto m_lane) {
const int src = n_idx * plane + m_lane; // source row in this N-plane
const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output
AccDataType v = shuffle_acc.get_thread_buffer()[src];
if constexpr(has_scalar_scales)
{
v = static_cast<AccDataType>(v * scale_m * scale_n);
}
else if constexpr(has_scales && !has_scalar_scales)
{
const auto sm = static_cast<float>(sm_tile.get_thread_buffer()[dst]);
const auto sn = static_cast<float>(sn_tile.get_thread_buffer()[dst]);
v = static_cast<AccDataType>(v * sm * sn);
}
c_out_tensor.get_thread_buffer()[dst] = type_convert<ODataType>(v);
});
});
// store/update
if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
// advance output (and any D-tensors) by one MPerXdl*MWave chunk
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
});
});
}
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM = EmptyScale,
typename ScaleN = EmptyScale,
int EnablePermuateN_ = TiledMMAPermuteN,
std::enable_if_t<!EnablePermuateN_, int> = 0>
typename ScaleM = EmptyScale,
typename ScaleN = EmptyScale>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,

View File

@@ -0,0 +1,375 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/host/concat.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include <type_traits>
namespace ck_tile {
template <typename AsDataType_,
typename BsDataType_,
typename DsDataType_,
typename AccDataType_,
typename ODataType_,
typename DsLayout_,
typename ELayout_,
typename CDElementwise_,
index_t kM_,
index_t kN_,
index_t MWave_,
index_t NWave_,
index_t MPerXdl_,
index_t NPerXdl_,
index_t KPerXdl_,
bool isCTransposed_,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1>
struct PermuteNEpilogueProblem
{
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;
static constexpr index_t NWave = NWave_;
static constexpr index_t MPerXdl = MPerXdl_;
static constexpr index_t NPerXdl = NPerXdl_;
static constexpr index_t KPerXdl = KPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t NumDTensor = DsDataType::size();
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
};
template <typename Problem_, typename Policy_ = void>
struct PermuteNEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<AsDataType>,
remove_cvref_t<tuple<AsDataType>>>;
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<BsDataType>,
remove_cvref_t<tuple<BsDataType>>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
using ATypeToUse = std::conditional_t<std::is_same_v<ADataType, pk_int4_t> ||
std::is_same_v<ADataType, pk_fp4_t>,
BDataType,
ADataType>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_t> ||
sizeof(BDataType) < sizeof(ADataType),
ADataType,
BDataType>;
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t MWave = Problem::MWave;
static constexpr index_t NWave = Problem::NWave;
static constexpr index_t MPerXdl = Problem::MPerXdl;
static constexpr index_t NPerXdl = Problem::NPerXdl;
static constexpr index_t KPerXdl = Problem::KPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
CDElementwise elfunc_;
// PermuteN epilogue does not support D tensors or non-passthrough elementwise operations.
// If D tensor support is needed, use CShuffleEpilogue instead.
static_assert(NumDTensor == 0,
"PermuteNEpilogue does not support D tensors. Use CShuffleEpilogue instead.");
static_assert(std::is_same_v<CDElementwise, element_wise::PassThrough>,
"PermuteNEpilogue only supports PassThrough elementwise. "
"Use CShuffleEpilogue for custom elementwise operations.");
CK_TILE_DEVICE PermuteNEpilogue(CDElementwise elfunc = CDElementwise{}) : elfunc_(elfunc) {};
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "PermuteNEpilogue",
concat('x', MWave, NWave),
concat('x', MPerXdl, NPerXdl, KPerXdl),
VectorSizeC,
isCTransposed ? "CTransposed" : "CNotTransposed");
// clang-format on
}
/**
* @brief Get the vector store size for C tensor.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
{
if constexpr(FixedVectorSize)
{
return VectorSizeC;
}
constexpr index_t max_vector_size = 16;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}
/**
* @brief Get the vector store size for Di tensor.
*
* @return The vector store size for Di tensor.
*/
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
{
constexpr index_t max_vector_size = 16;
using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else
{
static_assert(false, "Unsupported DLayout!");
}
return max_vector_size / sizeof(DiDataType);
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
using WG = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
AccDataType,
MPerXdl,
NPerXdl,
KPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
// TODO: Check if there would be nicer ways to overload rather than with EmptyScale or nullptr_t
struct EmptyScale
{
};
template <typename, typename = void>
struct ScaleDataType
{
using DataType = float;
};
template <typename T>
struct ScaleDataType<T, std::void_t<typename T::DataType>>
{
using DataType = typename T::DataType;
};
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM = EmptyScale,
typename ScaleN = EmptyScale>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* /* p_smem */,
const ScaleM& scale_m = {},
const ScaleN& scale_n = {})
{
static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size();
static_assert(MPerXdl % RowsPerLane == 0,
"PermuteN: MPerXdl must be divisible by per-lane row count.");
constexpr int kM0 = MWave;
constexpr int kM2 = RowsPerLane;
constexpr int kM1 = MPerXdl / kM2;
constexpr int kN0 = NWave;
constexpr int kN1 = NPerXdl;
constexpr int kN2 = NRepeat;
using IntrThreadShuffleEncode =
tile_distribution_encoding<sequence<>,
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>;
constexpr auto dram_tile_distribution =
make_static_tile_distribution(IntrThreadShuffleEncode{});
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
// Optional scales (must share the same distribution to match per-thread indexing)
constexpr bool has_scales =
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
constexpr bool has_scalar_scales =
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
// Tiles to hold row/col scales when present
using SMType = typename ScaleDataType<ScaleM>::DataType;
using SNType = typename ScaleDataType<ScaleN>::DataType;
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
// Build windows only if non-scalar scales are provided
auto scale_m_window = [&]() {
if constexpr(has_scales && !has_scalar_scales)
{
return make_tile_window(scale_m, dram_tile_distribution);
}
else
{
return EmptyScale{};
}
}();
auto scale_n_window = [&]() {
if constexpr(has_scales && !has_scalar_scales)
{
return make_tile_window(scale_n, dram_tile_distribution);
}
else
{
return EmptyScale{};
}
}();
static_for<0, MRepeat, 1>{}([&](auto mIter) {
// Slice accumulators for this M repeat into the permuted layout
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
// If non-scalar scales provided, load them with identical distribution
if constexpr(has_scales && !has_scalar_scales)
{
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
}
// Pack "rows per lane" with permuted N layout
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
// source indices in shuffle_acc: (n_idx * product(Y) + row)
const index_t plane = c_warp_y_lengths.product();
// Fuse scale (if present) and convert
static_for<0, kM2, 1>{}([&](auto m_lane) {
const int src = n_idx * plane + m_lane; // source row in this N-plane
const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output
AccDataType v = shuffle_acc.get_thread_buffer()[src];
if constexpr(has_scalar_scales)
{
v = static_cast<AccDataType>(v * scale_m * scale_n);
}
else if constexpr(has_scales && !has_scalar_scales)
{
const auto sm = static_cast<float>(sm_tile.get_thread_buffer()[dst]);
const auto sn = static_cast<float>(sn_tile.get_thread_buffer()[dst]);
v = static_cast<AccDataType>(v * sm * sn);
}
c_out_tensor.get_thread_buffer()[dst] = type_convert<ODataType>(v);
});
});
// store/update
if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
// advance output (and any D-tensors) by one MPerXdl*MWave chunk
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
});
});
}
};
} // namespace ck_tile

View File

@@ -339,16 +339,6 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
{
return hostArgs;
}
// CK_TILE_HOST static constexpr auto
// MakeKernelArgs(const ContiguousGroupedFlatmmHostArgs& hostArgs)
// {
// return hostArgs;
// }
// CK_TILE_HOST static constexpr auto
// MakeKernelArgs(const MaskedGroupedFlatmmHostArgs& hostArgs)
// {
// return hostArgs;
// }
template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,

View File

@@ -483,13 +483,6 @@ struct MoeFlatmmKernel
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
// if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
// {
// std::cerr << "Can't support N that is not a multiple of NPerBlock"
// " without padding!"
// << std::endl;
// return false;
// }
if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
{
std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;

View File

@@ -392,10 +392,6 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// constexpr index_t M0 = MPerBlock / (M2 * M1);
// static_assert(M0 * M1 * M2 == MPerBlock,
// "Incorrect M0, M2, M1 configuration! "
// "M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,

View File

@@ -1151,11 +1151,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
// if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
// {
// block_sync_lds();
// }
});
}
});
@@ -1636,10 +1631,6 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1
? Aload_rep
: 0;
}
// if((kIter % KPerScaleLoad == 0) && (mIter == 0))
// {
// load_perM = load_perM + 1;
// }
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
}

View File

@@ -103,13 +103,8 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
static constexpr index_t Aload_num_perK = dswrite_num_perK;
static constexpr index_t Aload_rep = dswrite_rep;
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
// static constexpr index_t ScaleBload_K1 = ContinuousScaleNPerThread *
// ContinuousScaleKPerThread; static constexpr index_t ScaleBload_num =
// kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 /
// WaveSize; // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize
// static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
@@ -352,10 +347,6 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
? Aload_rep
: 0;
}
// if((kIter % KPerScaleLoad == 0) && (mIter == 0))
// {
// load_perM = load_perM + 1;
// }
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
}

View File

@@ -390,10 +390,6 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
? Aload_rep
: 0;
}
// if((kIter % KPerScaleLoad == 0) && (mIter == 0))
// {
// load_perM = load_perM + 1;
// }
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
}

View File

@@ -759,18 +759,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.sink_ptr != nullptr
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
: -numeric<float>::infinity();
const index_t seqlen_k = [&]() {
// WA i_batch capture structure binding before c++20
const index_t seqlen_k = [&, i_batch_ = i_batch]() {
if constexpr(kKVLookupTable ==
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
{
const int32_t page_start = kargs.page_table.kv_indptr[i_batch];
const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1];
const int32_t page_start = kargs.page_table.kv_indptr[i_batch_];
const int32_t page_end = kargs.page_table.kv_indptr[i_batch_ + 1];
const int32_t num_page_blocks = page_end - page_start;
const int32_t last_page_len = [&]() {
if constexpr(kPageBlockSize == 1)
return static_cast<int32_t>(kPageBlockSize);
else
return kargs.page_table.kv_last_page_lens[i_batch];
return kargs.page_table.kv_last_page_lens[i_batch_];
}();
return num_page_blocks > 0
? static_cast<index_t>((num_page_blocks - 1) * kargs.page_block_size +
@@ -780,21 +781,22 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
{
if(kargs.page_table.seqlen_k_ptr != nullptr)
return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch]);
return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch_]);
else
return kargs.seqlen_k;
}
}();
const int32_t* page_idx = [&]() {
// WA i_batch capture structure binding before c++20
const int32_t* page_idx = [&, i_batch_ = i_batch]() {
if constexpr(kKVLookupTable ==
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
{
return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch];
return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch_];
}
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
{
return kargs.page_table.block_table_ptr +
static_cast<long_index_t>(i_batch) *
static_cast<long_index_t>(i_batch_) *
kargs.page_table.batch_stride_block_table;
}
}();

View File

@@ -39,6 +39,9 @@ struct FmhaFwdKernel
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
template <typename T>
using has_hdim_tail_args = decltype(T::kUseHdimTailArgs);
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
@@ -1891,6 +1894,35 @@ struct FmhaFwdKernel
}();
BlockIndices block_indices{i_batch, i_nhead, i_nhead_k};
constexpr bool kPassHdimTailArgs = [] {
if constexpr(ck_tile::is_detected<has_hdim_tail_args, FmhaPipeline>::value)
return static_cast<bool>(FmhaPipeline::kUseHdimTailArgs);
else
return false;
}();
auto invoke_fmha_pipeline = [&](auto&&... args) -> decltype(auto) {
if constexpr(kPassHdimTailArgs)
{
const ck_tile::index_t valid_k0_loops =
ck_tile::integer_divide_ceil(kargs.hdim_q, FmhaPipeline::kK0);
const ck_tile::index_t valid_last_k0_length =
kargs.hdim_q - (valid_k0_loops - 1) * FmhaPipeline::kK0;
const ck_tile::index_t valid_n1_length = [&]() {
const ck_tile::index_t remaining_n1 = kargs.hdim_v - i_n1;
return ck_tile::min(remaining_n1,
static_cast<ck_tile::index_t>(FmhaPipeline::kN1));
}();
return FmhaPipeline{}(static_cast<decltype(args)&&>(args)...,
sink_value,
valid_k0_loops,
valid_last_k0_length,
valid_n1_length);
}
else
{
return FmhaPipeline{}(static_cast<decltype(args)&&>(args)..., sink_value);
}
};
auto o_acc_tile = [&, i_nhead_ = i_nhead, i_nhead_k_ = i_nhead_k]() {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
@@ -1910,36 +1942,35 @@ struct FmhaFwdKernel
else
return ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
}();
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(scale_p)>>{
scale_p}, // p_compute_element_func
o_acc_element_func, // o_acc_element_func
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
dropout,
nullptr,
nullptr,
1,
make_null_tile_window(make_tuple()),
make_null_tile_window(make_tuple()),
make_null_tile_window(make_tuple()),
sink_value);
return invoke_fmha_pipeline(q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(scale_p)>>{
scale_p}, // p_compute_element_func
o_acc_element_func, // o_acc_element_func
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
dropout,
nullptr,
nullptr,
1,
make_null_tile_window(make_tuple()),
make_null_tile_window(make_tuple()),
make_null_tile_window(make_tuple()));
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
@@ -1964,7 +1995,7 @@ struct FmhaFwdKernel
// Both P and rowsum are scaled by 2^shift, canceling in normalization
// No additional scaling needed in p_compute_element_func or o_acc_element_func
return FmhaPipeline{}(
return invoke_fmha_pipeline(
q_dram_window,
identity{}, // q_element_func
k_dram_window,
@@ -1992,8 +2023,7 @@ struct FmhaFwdKernel
kargs.block_scale_size_kv,
make_null_tile_window(make_tuple()),
make_null_tile_window(make_tuple()),
make_null_tile_window(make_tuple()),
sink_value);
make_null_tile_window(make_tuple()));
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX)
{
@@ -2098,53 +2128,51 @@ struct FmhaFwdKernel
number<FmhaPipeline::kK1 / kVScaleGranularity>{}),
{i_n1, 0});
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
identity{}, // p_compute_element_func
identity{}, // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout,
nullptr,
nullptr,
1,
q_scale_dram_window,
k_scale_dram_window,
v_scale_dram_window,
sink_value);
return invoke_fmha_pipeline(q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
identity{}, // p_compute_element_func
identity{}, // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout,
nullptr,
nullptr,
1,
q_scale_dram_window,
k_scale_dram_window,
v_scale_dram_window);
}
else
{
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
lse_dram_window,
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
dropout,
sink_value);
return invoke_fmha_pipeline(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
lse_dram_window,
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
dropout);
}
}();

View File

@@ -291,6 +291,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static constexpr bool kHasSink = Problem::kHasSink;
// For KV_BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift]
// This avoids explicit P *= scale_p and v_descale /= scale_p operations
@@ -546,11 +547,25 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
}
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
else
{
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
const auto num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
@@ -576,7 +591,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
{kv_load_start, 0});
auto k_dist = Policy::template MakeKDramTileDistribution<Problem>();
auto k_coord = k_dist.calculate_index();
@@ -585,7 +600,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// kPageBlockSize >= kN0: within-page offset only (SRD rebased per page via rebase_k_window)
// kPageBlockSize < kN0: global offset, must fit int32
statically_indexed_array<index_t, NRepeat> k_offsets;
index_t current_seq_k = seqlen_k_start;
index_t current_seq_k = kv_load_start;
// Load physical pages first, then compute offsets.
// k_physical_pages can be reused for descale lookup later.
@@ -668,11 +683,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
randval_dram_block_window_tmp, kv_load_start);
auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
auto v_coord = v_dist.calculate_index();
@@ -895,7 +910,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
auto v_dram_window =
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
{0, kv_load_start}, // TODO: hdim split?
v_dist,
v_offsets,
number<1>{}, // HsGatherDim
@@ -1097,6 +1112,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
#endif
}
}
if constexpr(kHasSink)
{
if(i_total_loops == num_sink_loop - 1)
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
@@ -1108,19 +1128,36 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask_func(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
};
if constexpr(kHasSink)
{
apply_mask([&](auto&&... args) {
return variant.LogitsSinkMask(
std::forward<decltype(args)>(args)...);
});
}
else
{
apply_mask([&](auto&&... args) {
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
});
}
}
}
@@ -1297,12 +1334,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
auto randval_ptr = reinterpret_cast<char*>(smem_ptr) +
Policy::template GetSmemSizeKV<Problem>();
index_t seq_offset = [&]() {
if constexpr(kHasSink)
{
const bool in_sink_phase = (num_sink_loop > i_total_loops);
if(i_total_loops == num_sink_loop)
move_tile_window(randval_dram_window,
{0, seqlen_k_start - sink_seq_end});
return in_sink_phase
? (kv_load_start + i_total_loops * kN0)
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
}
else
return seqlen_k_start + i_total_loops * kN0;
}();
dropout
.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
randval_ptr, seq_offset, p_compute, randval_dram_window);
}
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
@@ -1396,9 +1444,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
i_total_loops++;
if(i_total_loops < num_total_loop)
{
current_seq_k += kN0;
// For sink: after the last sink tile, jump K/V to seqlen_k_start;
// otherwise advance by one normal tile.
const index_t k_advance = [&]() -> index_t {
if constexpr(kHasSink)
return (i_total_loops == num_sink_loop)
? (seqlen_k_start - sink_seq_end + kN0)
: kN0;
else
return kN0;
}();
current_seq_k += k_advance;
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
move_tile_window(k_dram_block_window, {k_advance, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
// KV_BLOCKSCALE: reload physical pages for the new tile
@@ -1427,6 +1485,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
k_dram_window.update_page_idx(k_offsets);
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
// After sink→window transition (i_total_loops == num_sink_loop), V window
// was advanced by kN0 (one normal iter), but current_seq_k jumped by k_advance
// = seqlen_k_start - sink_seq_end + kN0 > kN0. Re-init V to current_seq_k.
if constexpr(kHasSink)
{
if(i_total_loops == num_sink_loop && num_sink_loop > 0)
{
prefetch_v_physical_pages(number<0>{});
update_v_offsets(number<0>{});
v_dram_window.update_page_idx(v_offsets);
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
}
}
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();

View File

@@ -13,6 +13,7 @@ enum class BlockFmhaPipelineEnum
QSKSVS,
QRKSVS_ASYNC_TRLOAD,
QRKSVS_ASYNC_TRLOAD_V3,
QRKSVS_HPAD,
};
template <BlockFmhaPipelineEnum>
@@ -40,4 +41,10 @@ struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD>
static constexpr const char* name = "qr_async_trload";
};
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_HPAD>
{
static constexpr const char* name = "qr_hpad";
};
} // namespace ck_tile

View File

@@ -14,7 +14,9 @@
namespace ck_tile {
// This pipeline is qkv all located in LDS
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
template <typename Problem_,
typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy,
bool PaddedVecLoadStore_ = false>
struct BlockFmhaPipelineQRKSVS
{
using Problem = remove_cvref_t<Problem_>;
@@ -37,6 +39,11 @@ struct BlockFmhaPipelineQRKSVS
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
template <typename T>
using has_partial_k_support = decltype(T::kSupportsPartialK);
template <typename T>
using has_partial_n_support = decltype(T::kSupportsPartialN);
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
@@ -54,17 +61,19 @@ struct BlockFmhaPipelineQRKSVS
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static constexpr bool kHasSink = Problem::kHasSink;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static constexpr bool kHasSink = Problem::kHasSink;
static constexpr bool kPaddedVecLoadStore = PaddedVecLoadStore_;
static constexpr bool kUseHdimTailArgs = kPadHeadDimQ || kPadHeadDimV;
static constexpr ck_tile::index_t kQKScaleGranularity = Problem::kQKScaleGranularity;
static constexpr ck_tile::index_t kVScaleGranularity = Problem::kVScaleGranularity;
@@ -80,23 +89,29 @@ struct BlockFmhaPipelineQRKSVS
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
static_assert(!kPaddedVecLoadStore || (kPadHeadDimQ && kPadHeadDimV),
"padded vector load/store fast path only applies to padded head-dim kernels");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ = kPadHeadDimQ ? numeric_traits<QDataType>::PackedSize
: Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = kPadHeadDimQ ? numeric_traits<KDataType>::PackedSize
: Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentQ = (kPadHeadDimQ && !kPaddedVecLoadStore)
? numeric_traits<QDataType>::PackedSize
: Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = (kPadHeadDimQ && !kPaddedVecLoadStore)
? numeric_traits<KDataType>::PackedSize
: Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
return (kPadHeadDimV && !kPaddedVecLoadStore)
? 1
: Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? numeric_traits<VDataType>::PackedSize
: Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
(kPadHeadDimV && !kPaddedVecLoadStore) ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kAlignmentRandVal =
@@ -194,7 +209,10 @@ struct BlockFmhaPipelineQRKSVS
k_scale_dram_block_window_tmp, // N0*(K0/kQKScaleGranularity) tile
const VScaleDramBlockWindowTmp&
v_scale_dram_block_window_tmp, // N1*(K1/kVScaleGranularity) tile
const float sink_v) const
const float sink_v,
const index_t valid_k0_loops,
const index_t valid_last_k0_length,
const index_t valid_n1_length) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -252,8 +270,30 @@ struct BlockFmhaPipelineQRKSVS
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
using BlockGemm1 = remove_cvref_t<decltype(gemm_1)>;
constexpr bool kBlockGemm0SupportsPartialK = [] {
if constexpr(ck_tile::is_detected<has_partial_k_support, BlockGemm0>::value)
return static_cast<bool>(BlockGemm0::kSupportsPartialK);
else
return false;
}();
constexpr bool kBlockGemm1SupportsPartialN = [] {
if constexpr(ck_tile::is_detected<has_partial_n_support, BlockGemm1>::value)
return static_cast<bool>(BlockGemm1::kSupportsPartialN);
else
return false;
}();
constexpr auto gemm_0_config =
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using Gemm0WarpGemm = remove_cvref_t<decltype(gemm_0_config.template at<0>())>;
constexpr index_t kGemm0WarpK = Gemm0WarpGemm::kK;
constexpr index_t kGemm0KItersPerBlock = kK0 / kGemm0WarpK;
constexpr bool kUsePartialKForGemm0Tail =
kPadHeadDimQ && kBlockGemm0SupportsPartialK && (kGemm0KItersPerBlock > 1);
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
@@ -419,10 +459,26 @@ struct BlockFmhaPipelineQRKSVS
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
// Number of k0 iterations prefetched ahead of the current compute iteration.
// The skip decision must be made this many iterations before the last k0 loop.
constexpr index_t kK0PrefetchDepth = 2;
const index_t gemm0_tail_k_iters = [&]() {
if constexpr(kUsePartialKForGemm0Tail)
{
return ck_tile::integer_divide_ceil(valid_last_k0_length, kGemm0WarpK);
}
return static_cast<index_t>(kGemm0KItersPerBlock);
}();
const bool skip_last_k0_loop = [&]() {
if constexpr(kPadHeadDimQ)
{
return valid_k0_loops == (k0_loops - 1);
}
return false;
}();
// Use compile-time conditional for group barrier sequence
// (No runtime lambda selection)
auto schedule_gemm_0 = [] {
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
constexpr auto WarpGemmConfig =
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
@@ -447,7 +503,7 @@ struct BlockFmhaPipelineQRKSVS
}
};
static_assert(2 <= k0_loops);
static_assert(kK0PrefetchDepth <= k0_loops);
static_assert(1 <= k1_loops);
do
{
@@ -514,6 +570,46 @@ struct BlockFmhaPipelineQRKSVS
}
auto run_gemm_0 = [&](auto i_k0) {
if constexpr(kUsePartialKForGemm0Tail)
{
if(static_cast<index_t>(i_k0.value) == (valid_k0_loops - 1) &&
gemm0_tail_k_iters < kGemm0KItersPerBlock)
{
static_for<1, kGemm0KItersPerBlock, 1>{}([&](auto i_tail_k_iter) {
constexpr index_t kTailKIters = i_tail_k_iter;
constexpr index_t kTailK0 = kTailKIters * kGemm0WarpK;
if(gemm0_tail_k_iters == kTailKIters)
{
using Gemm0TailProblem = BlockGemmProblem<
QDataType,
KDataType,
SaccDataType,
Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape<
sequence<kM0, kN0, kTailK0>,
typename BlockFmhaShape::Gemm0BlockWarps,
sequence<BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
kGemm0WarpK>>>;
constexpr auto gemm_0_tail =
BlockGemmARegBSmemCRegV2<Gemm0TailProblem,
typename BlockGemm0::Policy>{};
auto q_slice =
get_slice_tile(q_tile,
sequence<0, i_k0 * kK0>{},
sequence<kM0, i_k0 * kK0 + kTailK0>{});
auto k_tail_window = make_tile_window(
k_lds, make_tuple(number<kN0>{}, number<kTailK0>{}), {0, 0});
gemm_0_tail(s_acc, q_slice, k_tail_window);
}
});
return;
}
}
auto q_slice = get_slice_tile(
q_tile, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{});
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX)
@@ -531,36 +627,78 @@ struct BlockFmhaPipelineQRKSVS
}
};
if constexpr(k0_loops > 2)
if constexpr(k0_loops > kK0PrefetchDepth)
{
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
static_for<0, k0_loops - kK0PrefetchDepth, 1>{}([&](auto i_k0) {
block_sync_lds();
run_gemm_0(number<i_k0>{});
block_sync_lds();
move_tile_window(k_dram_window, {0, kK0});
if constexpr(kPadHeadDimQ && i_k0 == (k0_loops - 1 - kK0PrefetchDepth))
{
if(!skip_last_k0_loop)
{
move_tile_window(k_dram_window, {0, kK0});
}
store_tile(
k_lds_window,
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
k_block_tile = load_tile(k_dram_window); // global read i + 2
store_tile(
k_lds_window,
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
if(!skip_last_k0_loop)
{
k_block_tile = load_tile(k_dram_window); // global read i + 2
}
}
else
{
move_tile_window(k_dram_window, {0, kK0});
store_tile(
k_lds_window,
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
k_block_tile = load_tile(k_dram_window); // global read i + 2
}
k_scale_block_tile = load_k_scale_block_tile();
});
}
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
block_sync_lds();
run_gemm_0(number<k0_loops - 2>{});
auto v_prefetch = decltype(load_tile(v_dram_window)){};
enum class VPrefetchPoint
{
BeforeGemm0Tail,
AfterGemm0Tail,
AfterSoftmax
};
#if defined(__gfx11__) || defined(__gfx12__)
constexpr auto kVPrefetch =
kPadHeadDimV ? VPrefetchPoint::AfterSoftmax : VPrefetchPoint::AfterGemm0Tail;
#else
constexpr auto kVPrefetch = VPrefetchPoint::BeforeGemm0Tail;
#endif
if constexpr(kVPrefetch == VPrefetchPoint::BeforeGemm0Tail)
{
load_tile(v_prefetch, v_dram_window); // prefetch load v tile
}
{ // tail
block_sync_lds();
run_gemm_0(number<k0_loops - kK0PrefetchDepth>{});
if(!skip_last_k0_loop)
{
block_sync_lds();
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
k_scale_block_tile = load_k_scale_block_tile();
k_scale_block_tile = load_k_scale_block_tile();
block_sync_lds();
block_sync_lds();
run_gemm_0(number<k0_loops - 1>{});
run_gemm_0(number<k0_loops - 1>{});
}
}
if constexpr(kVPrefetch == VPrefetchPoint::AfterGemm0Tail)
{
load_tile(v_prefetch, v_dram_window);
}
// dequant
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
@@ -819,6 +957,11 @@ struct BlockFmhaPipelineQRKSVS
randval_ptr, seq_offset, p_compute, randval_dram_window);
}
if constexpr(kVPrefetch == VPrefetchPoint::AfterSoftmax)
{
load_tile(v_prefetch, v_dram_window);
}
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
@@ -898,6 +1041,31 @@ struct BlockFmhaPipelineQRKSVS
auto o_acc0 = decltype(o_acc){};
clear_tile(o_acc0);
constexpr auto gemm_1_config =
BlockGemm1::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using Gemm1WarpGemm = remove_cvref_t<decltype(gemm_1_config.template at<0>())>;
constexpr index_t kGemm1NWarp = gemm_1_config.template at<2>();
constexpr index_t kGemm1NPerIter = kGemm1NWarp * Gemm1WarpGemm::kN;
const index_t valid_n_iters = [&]() {
if constexpr(kPadHeadDimV && kBlockGemm1SupportsPartialN)
{
return ck_tile::integer_divide_ceil(valid_n1_length, kGemm1NPerIter);
}
return static_cast<index_t>(0);
}();
auto run_gemm_1_impl =
[&](auto& o_acc_tensor, const auto& p_slice, const auto&... gemm_1_args) {
if constexpr(kPadHeadDimV && kBlockGemm1SupportsPartialN)
{
gemm_1(o_acc_tensor, p_slice, gemm_1_args..., valid_n_iters);
}
else
{
gemm_1(o_acc_tensor, p_slice, gemm_1_args...);
}
};
auto run_gemm_1 = [&](auto i_k1) {
auto p_slice =
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{});
@@ -907,15 +1075,19 @@ struct BlockFmhaPipelineQRKSVS
get_slice_tile(p_scale,
sequence<0, i_k1*(kK1 / kVScaleGranularity)>{},
sequence<kM0, (i_k1 + 1) * (kK1 / kVScaleGranularity)>{});
gemm_1(o_acc, p_slice, p_scale_slice, v_lds_window, v_scale_block_tile);
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
gemm_1(o_acc0, p_slice, v_lds_window);
run_gemm_1_impl(
o_acc, p_slice, p_scale_slice, v_lds_window, v_scale_block_tile);
}
else
{
gemm_1(o_acc, p_slice, v_lds_window);
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
run_gemm_1_impl(o_acc0, p_slice, v_lds_window);
}
else
{
run_gemm_1_impl(o_acc, p_slice, v_lds_window);
}
}
};
@@ -1040,6 +1212,94 @@ struct BlockFmhaPipelineQRKSVS
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices,
typename QScaleDramBlockWindowTmp,
typename KScaleDramBlockWindowTmp,
typename VScaleDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout,
const float* k_descale_ptr,
const float* v_descale_ptr,
const index_t block_scale_size_kv,
const QScaleDramBlockWindowTmp&
q_scale_dram_block_window_tmp, // M0*(K0/kQKScaleGranularity) tile
const KScaleDramBlockWindowTmp&
k_scale_dram_block_window_tmp, // N0*(K0/kQKScaleGranularity) tile
const VScaleDramBlockWindowTmp&
v_scale_dram_block_window_tmp, // N1*(K1/kVScaleGranularity) tile
const float sink_v) const
{
return operator()(q_dram_block_window_tmp,
q_element_func,
k_dram_block_window_tmp,
k_element_func,
v_dram_block_window_tmp,
v_element_func,
bias_dram_block_window_tmp,
bias_element_func,
randval_dram_block_window_tmp,
lse_dram_window_tmp,
lse_element_func,
s_acc_element_func,
p_compute_element_func,
o_acc_element_func,
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout,
k_descale_ptr,
v_descale_ptr,
block_scale_size_kv,
q_scale_dram_block_window_tmp,
k_scale_dram_block_window_tmp,
v_scale_dram_block_window_tmp,
sink_v,
kQKHeaddim / kK0,
kK0,
kN1);
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
@@ -1064,7 +1324,10 @@ struct BlockFmhaPipelineQRKSVS
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout,
const float sink_v) const
const float sink_v,
const index_t valid_k0_loops,
const index_t valid_last_k0_length,
const index_t valid_n1_length) const
{
return operator()(q_dram_block_window_tmp,
identity{},
@@ -1094,8 +1357,60 @@ struct BlockFmhaPipelineQRKSVS
make_null_tile_window(make_tuple()),
make_null_tile_window(make_tuple()),
make_null_tile_window(make_tuple()),
sink_v);
sink_v,
valid_k0_loops,
valid_last_k0_length,
valid_n1_length);
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout,
const float sink_v) const
{
return operator()(q_dram_block_window_tmp,
k_dram_block_window_tmp,
v_dram_block_window_tmp,
bias_dram_block_window_tmp,
randval_dram_block_window_tmp,
lse_dram_block_window_tmp,
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout,
sink_v,
kQKHeaddim / kK0,
kK0,
kN1);
}
};
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
using BlockFmhaPipelineQRKSVSHpad = BlockFmhaPipelineQRKSVS<Problem_, Policy_, true>;
} // namespace ck_tile

View File

@@ -692,9 +692,6 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
// constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad);
// constexpr index_t SingleVSize =
// MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
constexpr index_t BufferSize =
GetSingleSmemElementSpaceSize<Problem>(); // max(SingleKSize, SingleVSize);

View File

@@ -53,6 +53,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
BlockAttentionQuantScaleEnum QScaleEnum_,
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
bool kHasSink_ = false, /* StreamLLM sink tokens */
index_t kPageBlockSize_ = 1,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
@@ -70,7 +71,7 @@ struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
QScaleEnum_,
kBlockPerCu_,
kSkipMinSeqlenQ_,
false>
kHasSink_>
{
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;

View File

@@ -456,9 +456,6 @@ struct MoeSortingKernel
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
// constexpr int wave_size = 64;
// constexpr int reduce_stage = 6; // 1<<6=64
// clang-format off
constexpr int reduce_stage = [](){
if constexpr(wave_size_ == 2) return 1;
else if constexpr(wave_size_ == 4) return 2;
@@ -1206,17 +1203,21 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
// constexpr int wave_size = 64;
// constexpr int reduce_stage = 6; // 1<<6=64
// clang-format off
constexpr int reduce_stage = [](){
if constexpr(wave_size_ == 2) return 1;
else if constexpr(wave_size_ == 4) return 2;
else if constexpr(wave_size_ == 8) return 3;
else if constexpr(wave_size_ == 16) return 4;
else if constexpr(wave_size_ == 32) return 5;
else if constexpr(wave_size_ == 64) return 6;
else return 0;
constexpr int reduce_stage = []() {
if constexpr(wave_size_ == 2)
return 1;
else if constexpr(wave_size_ == 4)
return 2;
else if constexpr(wave_size_ == 8)
return 3;
else if constexpr(wave_size_ == 16)
return 4;
else if constexpr(wave_size_ == 32)
return 5;
else if constexpr(wave_size_ == 64)
return 6;
else
return 0;
}();
// clang-format on
T v_local = local;
@@ -3047,53 +3048,6 @@ struct MoeSortingMultiPhaseKernel_P23
x_r = x_v;
#endif
{
#if 0
#pragma unroll
for(int j = 0; j < index_pack / 2; j++)
{
int i_token = i * kBlockSize * index_pack + threadIdx.x + j * kBlockSize;
index_t x = x_d[j];
int i_topk = x - 1; // topk of this token
int i_show = x != 0 ? 1 : 0; // has this token or not
int cumsum = i_show;
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
__syncthreads();
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
});
cumsum += prev_cumsum; // add previous round cumsum
if(threadIdx.x == kBlockSize - 1)
{
s[0] = cumsum;
}
__syncthreads();
int position = cumsum - i_show;
prev_cumsum = s[0]; // update the last cumsum
if(i_show)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[e_start + position] =
MOE_SORTING_MOCK_ID(i_token, i_topk);
#else
p_sorted_token_ids[e_start + position] = i_token;
#endif
p_sorted_weights[e_start + position] =
p_weights[i_token * kargs.topk_mdiv.divisor + i_topk];
}
}
#endif
{
d_t i_topk;
d_t i_show;
@@ -3151,68 +3105,6 @@ struct MoeSortingMultiPhaseKernel_P23
}
position += i_show[j];
});
#if 0
int i_token = i * kBlockSize * index_pack + threadIdx.x * 2 + j * kBlockSize * 2;
index_t x = x_d[j];
index_t x0 = static_cast<index_t>(x & 0xffff);
index_t x1 = static_cast<index_t>(x >> 16);
int i_topk_0 = x0 - 1; // topk of this token
int i_show_0 = x0 != 0 ? 1 : 0; // has this token or not
int i_topk_1 = x1 - 1; // topk of this token
int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not
int cumsum = i_show_0 + i_show_1;
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
__syncthreads();
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
});
cumsum += prev_cumsum; // add previous round cumsum
if(threadIdx.x == kBlockSize - 1)
{
s[0] = cumsum;
}
__syncthreads();
int position_0 = cumsum - i_show_0 - i_show_1;
prev_cumsum = s[0]; // update the last cumsum
if(i_show_0)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[e_start + position_0] =
MOE_SORTING_MOCK_ID(i_token, i_topk_0);
#else
p_sorted_token_ids[e_start + position_0] = i_token;
#endif
p_sorted_weights[e_start + position_0] =
p_weights[i_token * kargs.topk_mdiv.divisor + i_topk_0];
}
int position_1 = cumsum - i_show_1;
if(i_show_1)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[e_start + position_1] =
MOE_SORTING_MOCK_ID(i_token + 1, i_topk_1);
#else
p_sorted_token_ids[e_start + position_1] = i_token + 1;
#endif
p_sorted_weights[e_start + position_1] =
p_weights[(i_token + 1) * kargs.topk_mdiv.divisor + i_topk_1];
}
#endif
}
}
}

View File

@@ -14,14 +14,6 @@
namespace ck_tile {
// template <typename Problem_, typename Policy_ = MoeSortingPolicy>
// struct MoeSortingPipeline
// {
// // TODO: this kernel only support warp per row
// using Problem = remove_cvref_t<Problem_>;
// using Policy = remove_cvref_t<Policy_>;
// using WeightType = typename Problem::WeightType;
// template <typename TopkIdWindow, typename WeightWindow>
// CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
// const WeightWindow& weight_window,

View File

@@ -36,9 +36,6 @@ struct BlockGemmARegBSmemCRegOneWarpV1
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;

View File

@@ -19,30 +19,7 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
constexpr index_t NumWarp = kBlockSize / get_warp_size();
// FIXME
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
}
else
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
}
#else
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
#endif
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&

View File

@@ -21,13 +21,18 @@ struct BlockGemmARegBSmemCRegV2
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr bool kSupportsPartialK = true;
static constexpr bool kSupportsPartialN = true;
// C += A * B
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
template <bool UsePartialN,
typename CBlockTensor,
typename ABlockTensorTmp,
typename BBlockWindowTmp>
CK_TILE_DEVICE void Impl(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp,
[[maybe_unused]] const index_t valid_n_iters) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
@@ -134,10 +139,7 @@ struct BlockGemmARegBSmemCRegV2
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
auto run_n_iter = [&](auto kIter, auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
@@ -166,7 +168,44 @@ struct BlockGemmARegBSmemCRegV2
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
};
// hot loop:
if constexpr(UsePartialN)
{
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if(static_cast<index_t>(nIter.value) < valid_n_iters)
{
run_n_iter(kIter, nIter);
}
});
});
}
else
{
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { run_n_iter(kIter, nIter); });
});
}
}
// C += A * B (executing only the first valid_n_iters N sub-iterations)
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp,
const index_t valid_n_iters) const
{
Impl<true>(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, valid_n_iters);
}
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
Impl<false>(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, 0);
}
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
@@ -227,7 +266,17 @@ struct BlockGemmARegBSmemCRegV2
return c_block_tensor;
}
// C = A * B
// C = A * B (executing only the first valid_n_iters N sub-iterations)
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp,
const index_t valid_n_iters) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, valid_n_iters);
return c_block_tensor;
}
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const

View File

@@ -16,30 +16,7 @@ struct BlockGemmARegBSmemCRegV2DefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
constexpr index_t NumWarp = kBlockSize / get_warp_size();
// FIXME
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
}
else
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
}
#else
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
#endif
}
};

View File

@@ -19,30 +19,7 @@ struct BlockGemmASmemBRegCRegV1DefaultPolicy
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
constexpr index_t NumWarp = kBlockSize / get_warp_size();
// FIXME
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
}
else
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
}
#else
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
#endif
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&

View File

@@ -5,52 +5,9 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
namespace ck_tile {
/**
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
* elements into lower part of a_vec to half its effective size.
* @param a_vec Vector to be compressed.
* @tparam ADataType The data type of a_vec
* @tparam CompressedSize The target compression size
* @tparam AVec The vector type of a_vec (deduced)
* @return Packed 32bit word containing **CompressedSize** 2bit fields.
* Each field encodes the original position (03) of the corresponding
* nonzero element in the input. If fewer than CompressedSize
* nonzeros are found, remaining fields default to 2 (see below).
*/
template <typename ADataType, index_t CompressedSize, typename AVec>
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
{
// idx holds one 2bit index per output element (total CompressedSize entries).
// It is initialized to the pattern 0b10 for every field. This matches
// what the hardware expects when there are fewer than two nonzero values
// in a 4element group the unused output is treated as coming from slot 2.
// The loop below will clear and set each field as real nonzeros are seen.
int32_t idx = 0;
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2 << (2 * k)); });
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
int32_t non_zero_pos = 0;
static_for<0, 3, 1>{}([&](auto j) {
if(a_vec[i * 4 + j] != 0.0f)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
// clear the twobit field for this output and insert j
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
idx |= j << 2 * (i * 2 + non_zero_pos);
++non_zero_pos;
}
});
a_vec[i * 2] = nonzero_elems[0];
a_vec[i * 2 + 1] = nonzero_elems[1];
});
return idx;
}
template <typename WarpGemmAttribute_>
struct WarpGemmSmfmacImpl
{
@@ -86,10 +43,37 @@ struct WarpGemmSmfmacImpl
return WarpGemmAttribute_::get_num_of_access();
}
template <index_t CompressedSize, typename AVec>
CK_TILE_DEVICE int32_t compress_a_vec(AVec& a_vec)
//----------------------------------------------------------------------------------------------
/// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
/// elements into lower part of a_vec to half its effective size.
///
/// @param a_vec Vector to be compressed.
///
/// @return Four 2-bit indexes of non-zero elements locations
///
template <typename AVec>
CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const
{
return compress_a_impl<ADataType, CompressedSize>(a_vec);
int32_t idx = 0b11101110;
static_for<0, 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
int32_t non_zero_pos = 0;
static_for<0, 3, 1>{}([&](auto j) {
if(a_vec[i * 4 + j] != 0.0f)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
idx |= j << 2 * (i * 2 + non_zero_pos);
++non_zero_pos;
}
});
a_vec[i * 2] = nonzero_elems[0];
a_vec[i * 2 + 1] = nonzero_elems[1];
});
return idx;
}
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
@@ -102,11 +86,10 @@ struct WarpGemmSmfmacImpl
constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
static constexpr index_t CompressedSize =
ATensor::get_thread_buffer_size() / CompressionRatio;
using AVecCompressed = ext_vector_t<ADataType, CompressedSize>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
using AVecCompressed =
ext_vector_t<ADataType, ATensor::get_thread_buffer_size() / CompressionRatio>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
@@ -114,9 +97,8 @@ struct WarpGemmSmfmacImpl
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
const int32_t idx = compress_a_vec<CompressedSize>(a_vec);
const int32_t idx = compress_a(a_vec);
static_assert(CompressedSize == 4);
// @TODO can we simply set a_vec_pruned to a_vec[0:3]?
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};

View File

@@ -120,10 +120,6 @@ struct BlockNormReduceSync
constexpr index_t idim_p_lane = NDimP - 1;
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
// const auto rs_idx =
// mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
@@ -360,17 +356,6 @@ struct BlockNormReduceCrossWarpSync
template <typename BlockShape>
CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_size)
{
#if 0
using S = BlockShape;
index_t LastloopN = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N;
constexpr index_t NThread = S::WarpPerBlock_N * S::ThreadPerWarp_N;
index_t iNLane = get_thread_id() % NThread;
index_t iN0 = LastloopN / (S::Vector_N * S::ThreadPerWarp_N);
index_t iN1 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) / S::Vector_N;
index_t N2 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) % S::Vector_N;
index_t iN3 = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0;
return iN0 * S::Vector_N + iN3;
#endif
using S_ = BlockShape;
constexpr index_t ThreadsPerBlock_N = S_::WarpPerBlock_N * S_::ThreadPerWarp_N;

View File

@@ -140,28 +140,6 @@ struct BlockReduce2d
ReducePacksPerXDim{});
}
#if 0
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto spans = XDistributedTensor_::get_distributed_spans();
// FIXME: hard coded to reduce 2nd axis
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
constexpr auto y_dstr_idx = make_tuple(dstr_idx_i0);
auto y = y_tensor[y_dstr_idx];
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
y = reduce_func(y, x);
});
y_tensor(y_dstr_idx) = y;
});
#endif
template <typename XDistributedTensor_>
CK_TILE_DEVICE static auto MakeYBlockTile()
{
@@ -240,10 +218,6 @@ struct BlockReduce2dSync
constexpr index_t idim_p_lane = NDimP - 1;
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
// const auto rs_idx =
// y_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
// loop over thread data

View File

@@ -0,0 +1,45 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* @file tile_load_store_microkernels.hpp
* @brief Generic tile store/load microkernels.
*
* Setup::create() must return:
* - For StoreTile: tuple<window, tile>
* - For LoadTile: window
*/
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename Setup>
struct StoreTile
{
static constexpr index_t kBlockSize = Setup::kBlockSize;
CK_TILE_DEVICE void operator()() const
{
auto [window, tile] = Setup::create();
store_tile(window, tile);
block_sync_lds();
}
};
template <typename Setup>
struct LoadTile
{
static constexpr index_t kBlockSize = Setup::kBlockSize;
CK_TILE_DEVICE void operator()() const
{
auto window = Setup::create();
[[maybe_unused]] volatile auto tile = load_tile(window);
block_sync_lds();
}
};
} // namespace ck_tile