mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
Merge branch 'develop' into users/yiding12/fmha-bwd-workspace
This commit is contained in:
@@ -19,19 +19,29 @@
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/mfma/selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp"
|
||||
|
||||
@@ -2166,27 +2166,11 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
#if 0
|
||||
thread_buffer<fp16_t, 8> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(fp16_t),
|
||||
static_cast<index_t>(coherence));
|
||||
#else
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<fp32x4_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf16_t>::value) // bf16
|
||||
|
||||
@@ -1992,27 +1992,11 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
#if 0
|
||||
thread_buffer<fp16_t, 8> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(fp16_t),
|
||||
static_cast<index_t>(coherence));
|
||||
#else
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<fp32x4_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf16_t>::value) // bf16
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
@@ -87,7 +89,7 @@ namespace ck_tile::core::arch::mma {
|
||||
*
|
||||
* (logical correctness). Applies to scale MFMA fp8, which due to the index matrix layout does not
|
||||
* allow arbitrary K perms to simplify layouts. This means the layout can only properly be described
|
||||
* with a Num Access value of at least 2.
|
||||
* with a Num Access value which is a multiple of 2.
|
||||
*
|
||||
* (load / store manipulation). It seems like the load and store tile functions end up looking for
|
||||
* the size of the smallest unmerged K dimension (K0) to determine how many elements should be
|
||||
@@ -102,13 +104,16 @@ namespace ck_tile::core::arch::mma {
|
||||
*
|
||||
* -- CMPerLane --
|
||||
* The number of M dim elements in each lane. In terms of unmerge sizes, it's equal to M0 * M2, i.e
|
||||
* the product of the sizes of the outermost and innermost dimensions after a double M unmerge.
|
||||
* the product of the sizes of the outermost and innermost dimensions after a double M unmerge. This
|
||||
* does not count a potential increased M dimension size from block hiding. In this case, we have M
|
||||
* = kCMBlock * M2 * M1 * M0 instead.
|
||||
*
|
||||
* -- CNumAccess --
|
||||
* Same as A / B NumAccess but for the M dim (so M2), but the mid-level code doesn't care about this
|
||||
* and will not try to request a specific value. Absolutely needed for logical correctness of
|
||||
* register mappings since we can not perform arbitrary M permutations without messing up the A
|
||||
* layout.
|
||||
* layout. This does not count a potential increased M dimension size from block hiding. In this
|
||||
* case, we have M = kCMBlock * M2 * M1 * M0 instead.
|
||||
*/
|
||||
|
||||
/**
|
||||
@@ -144,7 +149,7 @@ struct amdgcn_mma_base
|
||||
using CDataType = CDataType_;
|
||||
|
||||
// Fragment (MmaTile) sizes, check description above.
|
||||
static constexpr index_t kM = FragM; // M = M2 * M1 * M0
|
||||
static constexpr index_t kM = FragM; // M = M2 * M1 * M0 (* kCMBlocks when block-hiding)
|
||||
static constexpr index_t kN = FragN;
|
||||
static constexpr index_t kK = FragK; // K = K2 * K1 * K0
|
||||
|
||||
@@ -157,15 +162,37 @@ struct amdgcn_mma_base
|
||||
static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0
|
||||
static constexpr index_t kCMNumAccess = kCMNumAccess_; // M2
|
||||
|
||||
// K-dimension compression ratio for A matrix, always 2 for sparse intrinsics.
|
||||
static constexpr index_t kCompressionRatio = (OpFamily == MmaOpFamily::SPARSE) ? 2 : 1;
|
||||
|
||||
// Layout checks
|
||||
static_assert(kK % kABKPerLane == 0);
|
||||
static_assert(kABKPerLane % kAKNumAccess == 0);
|
||||
static_assert(kABKPerLane % kBKNumAccess == 0);
|
||||
static_assert(kCMPerLane % kCMNumAccess == 0);
|
||||
|
||||
// Register types (derived)
|
||||
static constexpr index_t WaveSize = WaveSize_;
|
||||
static_assert((kM * kK * kARepeat) % WaveSize == 0);
|
||||
static_assert((kM * kK * kARepeat) % (WaveSize * kCompressionRatio) == 0);
|
||||
static_assert((kN * kK * kBRepeat) % WaveSize == 0);
|
||||
static_assert((kM * kN) % WaveSize == 0);
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, kM * kK * kARepeat / WaveSize>;
|
||||
using AVecType = ext_vector_t<ADataType, kM * kK * kARepeat / WaveSize / kCompressionRatio>;
|
||||
using BVecType = ext_vector_t<BDataType, kN * kK * kBRepeat / WaveSize>;
|
||||
using CVecType = ext_vector_t<CDataType, kM * kN / WaveSize>;
|
||||
|
||||
// Block-hiding / repeat related traits (derived)
|
||||
static_assert(kARepeat == kBRepeat || !std::is_same_v<OpType, WmmaOp>);
|
||||
static_assert(kARepeat == 1 || kBRepeat == 1 || !std::is_same_v<OpType, MfmaOp>);
|
||||
static constexpr index_t kCMBlocks = std::is_same_v<OpType, MfmaOp> ? kBRepeat : 1;
|
||||
static constexpr index_t kCNBlocks = std::is_same_v<OpType, MfmaOp> ? kARepeat : 1;
|
||||
static_assert(kM % (kCMBlocks * kCMPerLane) == 0);
|
||||
static_assert(kN % kCNBlocks == 0);
|
||||
|
||||
// For the C matrix, the block dimension B is either put in the Vector dimension or the Lane
|
||||
// dimension. We can tell which by checking if we get the right Vector size.
|
||||
static constexpr bool CBlockDimInVecDim =
|
||||
kCMBlocks * kCNBlocks * kCMPerLane == vector_traits<CVecType>::vector_size;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -177,15 +204,30 @@ struct Unsupported;
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
#include <concepts>
|
||||
/**
|
||||
* @concept HasExecSignature
|
||||
* @brief Helper concept for exec signature check.
|
||||
*/
|
||||
template <typename MmaOp, typename... ExecArgs>
|
||||
concept HasExecSignature = requires {
|
||||
{
|
||||
MmaOp::exec(typename MmaOp::AVecType{},
|
||||
typename MmaOp::BVecType{},
|
||||
typename MmaOp::CVecType{},
|
||||
std::declval<ExecArgs>()...)
|
||||
} -> std::convertible_to<typename MmaOp::CVecType>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @concept MmaOpI
|
||||
* @brief Expresses the meta-data interface required for each MmaOp policy.
|
||||
*/
|
||||
// TODO: Make sure this actually matches amdgcn_mma.
|
||||
template <typename MmaOp>
|
||||
concept MmaOpI = requires(MmaOp op) {
|
||||
// Requires an op context
|
||||
typename MmaOp::OpType;
|
||||
typename MmaOp::OpFamily;
|
||||
{ MmaOp::OpFamily } -> std::convertible_to<MmaOpFamily>;
|
||||
|
||||
// Captures types for inputs / outputs to mma function
|
||||
typename MmaOp::ADataType;
|
||||
@@ -194,7 +236,6 @@ concept MmaOpI = requires(MmaOp op) {
|
||||
typename MmaOp::AVecType;
|
||||
typename MmaOp::BVecType;
|
||||
typename MmaOp::CVecType;
|
||||
|
||||
// Captures CK-specific layout properties
|
||||
{ MmaOp::kABKPerLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kAKNumAccess } -> std::convertible_to<unsigned int>;
|
||||
@@ -203,13 +244,8 @@ concept MmaOpI = requires(MmaOp op) {
|
||||
{ MmaOp::kBRepeat } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCMPerLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCMNumAccess } -> std::convertible_to<unsigned int>;
|
||||
|
||||
// Static exec function
|
||||
{
|
||||
MmaOp::exec(
|
||||
typename MmaOp::AVecType{}, typename MmaOp::BVecType{}, typename MmaOp::CVecType{})
|
||||
} -> std::convertible_to<typename MmaOp::CVecType>;
|
||||
};
|
||||
{ MmaOp::kCompressionRatio } -> std::convertible_to<unsigned int>;
|
||||
} && (HasExecSignature<MmaOp> || HasExecSignature<MmaOp, int> || HasExecSignature<MmaOp, int, int>);
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
@@ -248,7 +284,7 @@ struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1
|
||||
// clang-format on
|
||||
{
|
||||
// This is a default pass-through implementation that doesn't do anything practical.
|
||||
CK_TILE_DEVICE static CVecType const&
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
|
||||
{
|
||||
// Prints once across all thread blocks and threads.
|
||||
@@ -267,6 +303,8 @@ struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
// Include the implementations
|
||||
#include "wmma/wmma.hpp"
|
||||
#include "wmma/wmma.hpp" // should be included before the below headers
|
||||
|
||||
#include "mfma/mfma.hpp"
|
||||
#include "scale/scale.hpp"
|
||||
#include "sparse/sparse.hpp"
|
||||
|
||||
@@ -51,6 +51,82 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
|
||||
}
|
||||
};
|
||||
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 64u, 32u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 64u, 32u, 4u, 64u, 4, 1, 1, 1, 2, 16, 4, MfmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_f32_32x32x4f16(aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
static_cast<int>(CtrlFlags::Cbsz),
|
||||
static_cast<int>(CtrlFlags::Abid),
|
||||
static_cast<int>(CtrlFlags::Blgp))};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 64u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 32u, 64u, 4u, 64u, 4, 1, 2, 1, 1, 16, 4, MfmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_f32_32x32x4f16(aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
static_cast<int>(CtrlFlags::Cbsz),
|
||||
static_cast<int>(CtrlFlags::Abid),
|
||||
static_cast<int>(CtrlFlags::Blgp))};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 64u, 4u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 64u, 4u, 4u, 64u, 4, 1, 1, 1, 16, 4, 1, MfmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_f32_4x4x4f16(aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
static_cast<int>(CtrlFlags::Cbsz),
|
||||
static_cast<int>(CtrlFlags::Abid),
|
||||
static_cast<int>(CtrlFlags::Blgp))};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 4u, 64u, 4u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 4u, 64u, 4u, 64u, 4, 1, 16, 1, 1, 4, 1, MfmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_f32_4x4x4f16(aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
static_cast<int>(CtrlFlags::Cbsz),
|
||||
static_cast<int>(CtrlFlags::Abid),
|
||||
static_cast<int>(CtrlFlags::Blgp))};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for MFMA on GFX950 targets
|
||||
|
||||
@@ -28,9 +28,11 @@ struct MmaDefaultTransformsGfx9
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
struct MmaTransformsDefaultSelector<
|
||||
MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_all<enable_if_target_family_gfx9_t<CompilerTarget>,
|
||||
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::DENSE>>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsGfx9;
|
||||
};
|
||||
|
||||
@@ -1,230 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
#include "amdgcn_mma.hpp"
|
||||
#include "mma_selector.hpp"
|
||||
#include "mma_transforms.hpp"
|
||||
|
||||
#include "mfma/mfma.hpp"
|
||||
#include "wmma/wmma.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/*! @enum MmaAccumPolicy
|
||||
* @brief Accumulation order for Mma decomposition
|
||||
*/
|
||||
enum struct MmaAccumPolicy
|
||||
{
|
||||
// Decomposition and accumulation in row-major fragment order
|
||||
ROW_MAJOR,
|
||||
// Decomposition and accumulation in col-major fragment order
|
||||
COL_MAJOR
|
||||
};
|
||||
|
||||
/**
|
||||
* @class Mma
|
||||
* @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation
|
||||
* (e.g., mfma or wmma), this class performs fragment-wise (MmaTile) decomposition to
|
||||
* matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and
|
||||
* accumulates results into output WaveTile (C: WaveTileM x WaveTileN).
|
||||
* @tparam ADataType Data type of input WaveTile A
|
||||
* @tparam BDataType Data type of input WaveTile B
|
||||
* @tparam CDataType Data type of input/output WaveTile C (accumulator)
|
||||
* @tparam WaveTileM Mma WaveTile M dimension
|
||||
* @tparam WaveTileN Mma WaveTile K dimension
|
||||
* @tparam WaveTileK Mma WaveTile M dimension
|
||||
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam MmaOp Backend wrapper class that will perform the mma op (e.g., mfma or wmma)
|
||||
* @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles
|
||||
* @par This is an example of an Mma decomposition driver class that can be used in a wave-tile
|
||||
* context. Given a WaveTile size, we can decompose the WaveTile into smaller mma op fragments
|
||||
* that are natively supported by the hardware (e.g., mfma or wmma). The class also supports
|
||||
* applying transforms to the input/output frags as needed (e.g., layout conversions, data type
|
||||
* conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the
|
||||
* output WaveTile. This is a powerful example of how to build a flexible and reusable mma driver
|
||||
* that can adapt to different hardware capabilities and requirements.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileK,
|
||||
MmaOpFamily OpFamily,
|
||||
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
|
||||
typename CompilerTarget =
|
||||
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
|
||||
// get_compiler_target(),
|
||||
typename MmaOp =
|
||||
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
|
||||
// MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
CompilerTarget,
|
||||
OpFamily>::SelectedOp,
|
||||
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
|
||||
typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
|
||||
struct WaveWiseMma
|
||||
{
|
||||
using FragWiseMmaOp = MmaOp;
|
||||
|
||||
// Fragment dimensions
|
||||
constexpr static uint32_t FragM = MmaOp::kM;
|
||||
constexpr static uint32_t FragN = MmaOp::kN;
|
||||
constexpr static uint32_t FragK = MmaOp::kK;
|
||||
|
||||
// Fragment counts for decomposition
|
||||
constexpr static uint32_t FragsM = WaveTileM / FragM;
|
||||
constexpr static uint32_t FragsN = WaveTileN / FragN;
|
||||
constexpr static uint32_t FragsK = WaveTileK / FragK;
|
||||
constexpr static uint32_t FragsC = FragsM * FragsN;
|
||||
|
||||
// Vector types for packed registers in each fragment
|
||||
using AVecType = typename MmaOp::AVecType;
|
||||
using BVecType = typename MmaOp::BVecType;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
|
||||
// Buffer types for WaveTiles
|
||||
using ABufferType = AVecType[FragsM][FragsK];
|
||||
using BBufferType = BVecType[FragsN][FragsK];
|
||||
using CBufferType = CVecType[FragsM][FragsN];
|
||||
|
||||
// Transforms
|
||||
using ATransform = typename MmaTransforms::ATransform;
|
||||
using BTransform = typename MmaTransforms::BTransform;
|
||||
using CTransform = typename MmaTransforms::CTransform;
|
||||
using DTransform = typename MmaTransforms::DTransform;
|
||||
|
||||
// Sanity checks
|
||||
static_assert(WaveTileM >= FragM, "WaveTileM must be larger than FragM");
|
||||
static_assert(WaveTileN >= FragN, "WaveTileN must be larger than FragN");
|
||||
static_assert(WaveTileK >= FragK, "WaveTileK must be larger than FragK");
|
||||
static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM");
|
||||
static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN");
|
||||
static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK");
|
||||
|
||||
private:
|
||||
template <typename DstT, typename SrcT>
|
||||
CK_TILE_DEVICE static auto formatBuffer(SrcT const& inputBuffer)
|
||||
{
|
||||
// TODO: Implement formatting logic as needed.
|
||||
// This is intended to convert input WaveTiles to the native vector types
|
||||
// required by the FragWiseMma operation for iteration
|
||||
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
|
||||
return reinterpret_cast<DstT const&>(inputBuffer);
|
||||
}
|
||||
|
||||
template <typename DstT, typename SrcT>
|
||||
CK_TILE_DEVICE static auto formatBuffer(SrcT& inputBuffer)
|
||||
{
|
||||
// TODO: Implement formatting logic as needed.
|
||||
// This is intended to convert input WaveTiles to the native vector types
|
||||
// required by the FragWiseMma operation for iteration
|
||||
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
|
||||
return reinterpret_cast<DstT&>(inputBuffer);
|
||||
}
|
||||
|
||||
/*! @brief Execute Mma in row-major accumulation order.
|
||||
* @tparam VecTA The input WaveTile A vector type
|
||||
* @tparam VecTB The input WaveTile B vector type
|
||||
* @tparam VecTC The input/output WaveTile C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
{
|
||||
// We implement an example wave-tile pipeline here.
|
||||
// First, we apply the necessary transforms to the input fragments,
|
||||
// then we convert the result into buffers of native vector formats
|
||||
// that we can easily index. Native vector formats are necessary inputs
|
||||
// to the given MmaOp exec function.
|
||||
auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
|
||||
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
|
||||
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
|
||||
|
||||
// "Col-major" accumulation over the M-dimension fragments first.
|
||||
// Pseudo code here, but we would basically iterate over the fragments in col-major order
|
||||
for(uint32_t bn = 0u; bn < FragsN; ++bn)
|
||||
{
|
||||
for(uint32_t bm = 0u; bm < FragsM; ++bm)
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < FragsK; ++bk)
|
||||
{
|
||||
c_frag[bm][bn] =
|
||||
FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert native vector results back to the output WaveTile format
|
||||
// and then return after we apply the final output transform.
|
||||
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
|
||||
}
|
||||
|
||||
/*! @brief Execute Mma in row-major accumulation order.
|
||||
* @tparam VecTA The input WaveTile A vector type
|
||||
* @tparam VecTB The input WaveTile B vector type
|
||||
* @tparam VecTC The input/output WaveTile C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
{
|
||||
// We implement an example wave-tile pipeline here.
|
||||
// First, we apply the necessary transforms to the input WaveTiles,
|
||||
// then we convert the result into buffers of native vector formats
|
||||
// that we can easily index. Native vector formats are necessary inputs
|
||||
// to the given MmaOp exec function.
|
||||
auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
|
||||
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
|
||||
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
|
||||
|
||||
// "Row-major" accumulation over the N-dimension fragments first.
|
||||
// Pseudo code here, but we would basically iterate over the fragments in row-major order.
|
||||
// We also have to ensure that the incoming vector WaveTiles are converted to native vector
|
||||
// types before passing to the FragWiseMma exec function.
|
||||
for(uint32_t bm = 0u; bm < FragsM; ++bm)
|
||||
{
|
||||
for(uint32_t bn = 0u; bn < FragsN; ++bn)
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < FragsK; ++bk)
|
||||
{
|
||||
c_frag[bm][bn] =
|
||||
FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert native vector results back to the output WaveTile format
|
||||
// and then return after we apply the final output transform.
|
||||
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
|
||||
}
|
||||
|
||||
public:
|
||||
/*! @brief Forward to Mma operation with specified accumulation order.
|
||||
* @tparam VecTA The input WaveTile A vector type
|
||||
* @tparam VecTB The input WaveTile B vector type
|
||||
* @tparam VecTC The input/output WaveTile C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
{
|
||||
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
|
||||
{
|
||||
return exec_row_major(
|
||||
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
|
||||
}
|
||||
else // if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
|
||||
{
|
||||
return exec_col_major(
|
||||
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
343
include/ck_tile/core/arch/mma/mma_pipeline.hpp
Normal file
343
include/ck_tile/core/arch/mma/mma_pipeline.hpp
Normal file
@@ -0,0 +1,343 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
#include "amdgcn_mma.hpp"
|
||||
#include "mma_selector.hpp"
|
||||
#include "mma_traits.hpp"
|
||||
#include "mma_transforms.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/*! @enum MmaPipelineOptionFlag
|
||||
* @brief Individual option flags for configuring MmaPipeline behavior.
|
||||
*/
|
||||
enum struct MmaPipelineOptionFlag : unsigned
|
||||
{
|
||||
NONE = 0x0, ///< No flags set
|
||||
ABSwap = 0x1, ///< Swap A and B inputs to transpose the C output
|
||||
COMPRESS_A = 0x2, ///< Enable compressed (sparse) A matrix input
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaPipelineOptionFlags
|
||||
* @brief Type-safe bitmask wrapper for combining @ref MmaPipelineOptionFlag values.
|
||||
* @par Provides bitwise OR, AND, NOT, and equality operators for composing
|
||||
* and querying pipeline option flags.
|
||||
*/
|
||||
struct MmaPipelineOptionFlags
|
||||
{
|
||||
using Type = std::underlying_type_t<MmaPipelineOptionFlag>;
|
||||
|
||||
explicit constexpr MmaPipelineOptionFlags() : mFlags(0) {}
|
||||
explicit constexpr MmaPipelineOptionFlags(Type value) : mFlags(value) {}
|
||||
constexpr MmaPipelineOptionFlags(MmaPipelineOptionFlag singleFlag) : mFlags(toType(singleFlag))
|
||||
{
|
||||
}
|
||||
constexpr MmaPipelineOptionFlags(const MmaPipelineOptionFlags& original)
|
||||
: mFlags(original.mFlags)
|
||||
{
|
||||
}
|
||||
|
||||
constexpr MmaPipelineOptionFlags& operator|=(MmaPipelineOptionFlag addValue)
|
||||
{
|
||||
mFlags |= toType(addValue);
|
||||
return *this;
|
||||
}
|
||||
constexpr MmaPipelineOptionFlags operator|(MmaPipelineOptionFlag addValue) const
|
||||
{
|
||||
MmaPipelineOptionFlags result(*this);
|
||||
result |= addValue;
|
||||
return result;
|
||||
}
|
||||
constexpr MmaPipelineOptionFlags& operator&=(MmaPipelineOptionFlag maskValue)
|
||||
{
|
||||
mFlags &= toType(maskValue);
|
||||
return *this;
|
||||
}
|
||||
constexpr MmaPipelineOptionFlags operator&(MmaPipelineOptionFlag maskValue) const
|
||||
{
|
||||
MmaPipelineOptionFlags result(*this);
|
||||
result &= maskValue;
|
||||
return result;
|
||||
}
|
||||
constexpr MmaPipelineOptionFlags operator~() const
|
||||
{
|
||||
MmaPipelineOptionFlags result(*this);
|
||||
result.mFlags = ~result.mFlags;
|
||||
return result;
|
||||
}
|
||||
constexpr bool testFlag(MmaPipelineOptionFlag flag) const
|
||||
{
|
||||
return (flag == MmaPipelineOptionFlag::NONE) ? mFlags == toType(flag) : *this & flag;
|
||||
}
|
||||
constexpr operator bool() const { return mFlags != toType(MmaPipelineOptionFlag::NONE); }
|
||||
constexpr bool operator==(Type rhs) const { return mFlags == rhs; }
|
||||
|
||||
private:
|
||||
Type mFlags;
|
||||
static constexpr Type toType(MmaPipelineOptionFlag f) { return static_cast<Type>(f); }
|
||||
};
|
||||
|
||||
constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOptionFlags& rhs)
|
||||
{
|
||||
return rhs == lhs;
|
||||
}
|
||||
|
||||
/**
|
||||
* @class MmaPipelineBase
|
||||
* @brief CRTP base class that implements the common Mma pipeline logic shared by
|
||||
* all concrete pipeline drivers (e.g., dense wave-wise, sparse, etc.).
|
||||
*
|
||||
* @tparam Flags_ Compile-time bitmask of @ref MmaPipelineOptionFlag controlling
|
||||
* pipeline behavior (e.g., C transposition, A compression).
|
||||
* @tparam Derived The concrete CRTP-derived pipeline class. Must expose:
|
||||
* - Type aliases: @c InternalAVecT, @c InternalBVecT, @c InternalCVecT,
|
||||
* @c CVecType, @c MmaOp
|
||||
* - Transform aliases: @c ATransform, @c BTransform, @c CTransform,
|
||||
* @c DTransform
|
||||
* - A static @c execImpl(std::tuple<A,B,C>&) method.
|
||||
*
|
||||
* @par The pipeline performs the following steps in @c exec():
|
||||
* 1. Apply pre-transforms and format input buffers (A, B, C).
|
||||
* 2. Delegate to @c Derived::execImpl for the actual mma loop.
|
||||
* 3. Apply post-transform and format the output buffer (D) back to the user type.
|
||||
* When @c ABSwap is set, the A and B inputs are swapped before step 1.
|
||||
*/
|
||||
// TODO: c++20: use MmaPipelineOptionFlags directly
|
||||
template <MmaPipelineOptionFlags::Type Flags_, typename Derived>
|
||||
struct MmaPipelineBase
|
||||
{
|
||||
static constexpr auto Flags = MmaPipelineOptionFlags(Flags_);
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Reconstruct a tuple with its first element passed through @c formatBuffer
|
||||
* while preserving all remaining elements unchanged.
|
||||
* @tparam DstT Target type for the formatted first element.
|
||||
* @tparam SrcT Forwarding-reference type of the input tuple.
|
||||
* @tparam Is Index pack for elements 1..N-1 of the tuple.
|
||||
* @param inputTuple The source tuple whose first element will be formatted.
|
||||
* @return A new tuple with the formatted first element and the remaining elements forwarded.
|
||||
*/
|
||||
template <typename DstT, typename SrcT, std::size_t... Is>
|
||||
CK_TILE_DEVICE static auto formatBufferTupleImpl(SrcT&& inputTuple, std::index_sequence<Is...>)
|
||||
{
|
||||
auto&& first_elem = std::get<0>(std::forward<SrcT>(inputTuple));
|
||||
using FirstElemResultType =
|
||||
decltype(formatBuffer<DstT>(std::forward<decltype(first_elem)>(first_elem)));
|
||||
using InputTupleType = ck_tile::remove_cvref_t<SrcT>;
|
||||
return std::tuple<FirstElemResultType, std::tuple_element_t<Is + 1, InputTupleType>...>(
|
||||
formatBuffer<DstT>(std::forward<decltype(first_elem)>(first_elem)),
|
||||
std::get<Is + 1>(std::forward<SrcT>(inputTuple))...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Format (reinterpret-cast) a buffer to the hardware-native vector type @p DstT.
|
||||
*
|
||||
* Three cases are handled:
|
||||
* - **Tuple**: recursively format the first element via @c formatBufferTupleImpl,
|
||||
* preserving any metadata in the remaining tuple elements.
|
||||
* - **Array / Pointer**: forwarded unchanged.
|
||||
* - **Scalar / Vector**: reinterpret-cast to @p DstT (sizes must match).
|
||||
*
|
||||
* @tparam DstT The target hardware vector type.
|
||||
* @tparam SrcT Forwarding-reference type of the input buffer.
|
||||
* @param inputBuffer The buffer to format.
|
||||
* @return A reference (or value) of type @p DstT corresponding to @p inputBuffer.
|
||||
*/
|
||||
template <typename DstT, typename SrcT>
|
||||
CK_TILE_DEVICE static decltype(auto) formatBuffer(SrcT&& inputBuffer)
|
||||
{
|
||||
using DecayedSrcT = ck_tile::remove_cvref_t<SrcT>;
|
||||
|
||||
// If SrcT is a tuple, extract the first element (the vector) and format it
|
||||
// while preserving all remaining elements (metadata)
|
||||
if constexpr(is_std_tuple_v<DecayedSrcT>)
|
||||
{
|
||||
// Create index sequence for all remaining elements (skip first)
|
||||
constexpr std::size_t tuple_size = std::tuple_size_v<DecayedSrcT>;
|
||||
return formatBufferTupleImpl<DstT>(std::forward<SrcT>(inputBuffer),
|
||||
std::make_index_sequence<tuple_size - 1>{});
|
||||
}
|
||||
else if constexpr(std::is_array_v<DecayedSrcT> || std::is_pointer_v<DecayedSrcT>)
|
||||
{
|
||||
return std::forward<SrcT>(inputBuffer);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(sizeof(DstT) == sizeof(DecayedSrcT), "Size mismatch in formatBuffer");
|
||||
|
||||
using QualifiedDstT =
|
||||
std::conditional_t<std::is_const_v<DecayedSrcT>, DstT const, DstT>;
|
||||
|
||||
return reinterpret_cast<QualifiedDstT&>(inputBuffer);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
/** @brief Query whether a specific @ref MmaPipelineOptionFlag is set. */
|
||||
template <MmaPipelineOptionFlag Flag>
|
||||
constexpr CK_TILE_DEVICE static bool hasFlag()
|
||||
{
|
||||
return Flags.testFlag(Flag);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Apply a transform **then** format the result to @p DstT.
|
||||
* Used for input operands (A, B, C) before the mma loop.
|
||||
*/
|
||||
template <typename DstT, typename Transform, typename... Args>
|
||||
CK_TILE_DEVICE static auto preApplyTransform(Args&&... args)
|
||||
{
|
||||
return formatBuffer<DstT>(Transform::exec(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Format a buffer to @p DstT **then** apply a transform.
|
||||
* Used for the output operand (D) after the mma loop.
|
||||
*/
|
||||
template <typename DstT, typename Transform, typename... Args>
|
||||
CK_TILE_DEVICE static auto postApplyTransform(Args&&... args)
|
||||
{
|
||||
return Transform::exec(formatBuffer<DstT>(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Apply the per-operand pre-transforms and buffer formatting to A, B, and C.
|
||||
* @return A @c std::tuple of the transformed (A, B, C, [scaleA, scaleB]) vectors ready for the
|
||||
* mma loop.
|
||||
*/
|
||||
template <typename ATransformInputs,
|
||||
typename BTransformInputs,
|
||||
typename CTransformInputs,
|
||||
typename... ExtraArgs>
|
||||
CK_TILE_DEVICE static decltype(auto) applyTransformsToInputs(ATransformInputs&& a,
|
||||
BTransformInputs&& b,
|
||||
CTransformInputs&& accum,
|
||||
ExtraArgs&&... extras)
|
||||
{
|
||||
using InternalAVecT = typename Derived::InternalAVecT;
|
||||
using InternalBVecT = typename Derived::InternalBVecT;
|
||||
using InternalCVecT = typename Derived::InternalCVecT;
|
||||
|
||||
using ATransform = typename Derived::ATransform;
|
||||
using BTransform = typename Derived::BTransform;
|
||||
using CTransform = typename Derived::CTransform;
|
||||
|
||||
return std::make_tuple(
|
||||
preApplyTransform<InternalAVecT, ATransform>(std::forward<ATransformInputs>(a)),
|
||||
preApplyTransform<InternalBVecT, BTransform>(std::forward<BTransformInputs>(b)),
|
||||
preApplyTransform<InternalCVecT, CTransform>(std::forward<CTransformInputs>(accum)),
|
||||
std::forward<ExtraArgs>(extras)...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Apply the post-transform and buffer formatting to the C (accumulator) output.
|
||||
* @param c_result The accumulator to post-process.
|
||||
* @return The final D output in the user-facing vector type.
|
||||
*/
|
||||
template <typename CTransformResult>
|
||||
CK_TILE_DEVICE static auto applyTransformToOutput(CTransformResult&& c_result)
|
||||
{
|
||||
static_assert(!is_std_tuple_v<decltype(c_result)>,
|
||||
"If CTransform returns more than the vector, update this function.");
|
||||
|
||||
using CVecT = typename Derived::CVecType;
|
||||
using DTransform = typename Derived::DTransform;
|
||||
return postApplyTransform<CVecT, DTransform>(c_result);
|
||||
}
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Entry point: execute the full Mma pipeline (transforms + mma loop + output).
|
||||
* @tparam VecTA Type of the A WaveTile buffer.
|
||||
* @tparam VecTB Type of the B WaveTile buffer.
|
||||
* @tparam VecTC Type of the C (accumulator) WaveTile buffer.
|
||||
* @param a Input WaveTile A.
|
||||
* @param b Input WaveTile B.
|
||||
* @param accum Input/output accumulator WaveTile C.
|
||||
* @return The output WaveTile D after accumulation and post-transform.
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
{
|
||||
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
|
||||
{
|
||||
auto transformed_inputs = applyTransformsToInputs(
|
||||
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTB>(b)
|
||||
: std::forward<VecTA>(a),
|
||||
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTA>(a)
|
||||
: std::forward<VecTB>(b),
|
||||
std::forward<VecTC>(accum));
|
||||
|
||||
Derived::execImpl(transformed_inputs);
|
||||
|
||||
auto&& [a_result, b_result, c_result] = std::move(transformed_inputs);
|
||||
return applyTransformToOutput(std::move(c_result));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return the unsupported exec. This should print a runtime warning. (amdgcn_mma.hpp)
|
||||
// Code should not reach here, but HOST/DEVICE compile passes are
|
||||
// weirdly intertwined and instead of having constexpr in the calling
|
||||
// site (tests) we do this. See also changes by this commit.
|
||||
return Derived::MmaOp::exec({}, {}, {});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename VecTA,
|
||||
typename VecTB,
|
||||
typename VecTC,
|
||||
typename ScaleADataType,
|
||||
typename ScaleBDataType>
|
||||
CK_TILE_DEVICE static decltype(auto)
|
||||
exec(VecTA&& a, VecTB&& b, VecTC&& accum, ScaleADataType&& scale_A, ScaleBDataType&& scale_B)
|
||||
{
|
||||
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
|
||||
{
|
||||
// TODO: c++20: Call template functions with MmaPipelineOptionFlags directly
|
||||
auto transformed_inputs = applyTransformsToInputs(
|
||||
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTB>(b)
|
||||
: std::forward<VecTA>(a),
|
||||
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTA>(a)
|
||||
: std::forward<VecTB>(b),
|
||||
std::forward<VecTC>(accum),
|
||||
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<ScaleBDataType>(scale_B)
|
||||
: std::forward<ScaleADataType>(scale_A),
|
||||
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<ScaleADataType>(scale_A)
|
||||
: std::forward<ScaleBDataType>(scale_B));
|
||||
|
||||
Derived::execImpl(transformed_inputs);
|
||||
|
||||
auto&& [a_result, b_result, c_result, scale_A_result, scale_B_result] =
|
||||
std::move(transformed_inputs);
|
||||
return applyTransformToOutput(std::move(c_result));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return the unsupported exec. This should print a runtime warning. (amdgcn_mma.hpp)
|
||||
// Code should not reach here, but HOST/DEVICE compile passes are
|
||||
// weirdly intertwined and instead of having constexpr in the calling
|
||||
// site (tests) we do this. See also changes by this commit.
|
||||
return Derived::MmaOp::exec({}, {}, {});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
#include <concepts>
|
||||
|
||||
/**
|
||||
* @concept MmaPipelineI
|
||||
* @brief Expresses the meta-data interface required for a CRTP MmaPipeline.
|
||||
*/
|
||||
template <typename Derived, MmaPipelineOptionFlags::Type Flags>
|
||||
concept MmaPipelineInterface = std::derived_from<Derived, MmaPipelineBase<Flags, Derived>>;
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
@@ -72,3 +72,4 @@ concept MmaSelectorI = requires(MmaSelector op) {
|
||||
// Include the implementations
|
||||
#include "wmma/wmma_selector.hpp"
|
||||
#include "mfma/mfma_selector.hpp"
|
||||
#include "sparse/sparse_selector.hpp"
|
||||
|
||||
@@ -18,6 +18,18 @@ struct PassThroughTransform
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultPassThroughTransforms
|
||||
* @brief Implements the default MMA transforms
|
||||
*/
|
||||
struct MmaDefaultPassThroughTransforms
|
||||
{
|
||||
using ATransform = PassThroughTransform;
|
||||
using BTransform = PassThroughTransform;
|
||||
using CTransform = PassThroughTransform;
|
||||
using DTransform = PassThroughTransform;
|
||||
};
|
||||
|
||||
/**
|
||||
* @class MmaTransformsDefaultSelector
|
||||
* @brief Default selector for MmaTransforms based on MmaOp and CompilerTarget
|
||||
@@ -27,7 +39,10 @@ struct PassThroughTransform
|
||||
*/
|
||||
template <typename MmaOp, typename CompilerTarget, typename Enable = void>
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget, typename Enable = void>
|
||||
struct MmaTransformsDefaultSelector;
|
||||
struct MmaTransformsDefaultSelector
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultPassThroughTransforms;
|
||||
};
|
||||
|
||||
#if CK_TILE_CONCEPTS
|
||||
|
||||
|
||||
177
include/ck_tile/core/arch/mma/mma_wavewise.hpp
Normal file
177
include/ck_tile/core/arch/mma/mma_wavewise.hpp
Normal file
@@ -0,0 +1,177 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
#include "amdgcn_mma.hpp"
|
||||
#include "mma_pipeline.hpp"
|
||||
#include "mma_selector.hpp"
|
||||
#include "mma_transforms.hpp"
|
||||
|
||||
#include "mfma/mfma.hpp"
|
||||
#include "wmma/wmma.hpp"
|
||||
#include <tuple>
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/*! @enum MmaAccumPolicy
|
||||
* @brief Accumulation order for Mma decomposition
|
||||
*/
|
||||
enum struct MmaAccumPolicy
|
||||
{
|
||||
// Decomposition and accumulation in row-major fragment order
|
||||
ROW_MAJOR,
|
||||
// Decomposition and accumulation in col-major fragment order
|
||||
COL_MAJOR
|
||||
};
|
||||
|
||||
namespace dense::wavewise::detail {
|
||||
// TODO: c++20: return MmaPipelineOptionFlags directly
|
||||
template <bool SwapAB>
|
||||
constexpr inline int getPipelineFlags()
|
||||
{
|
||||
return static_cast<int>(SwapAB ? MmaPipelineOptionFlag::ABSwap : MmaPipelineOptionFlag::NONE);
|
||||
}
|
||||
} // namespace dense::wavewise::detail
|
||||
|
||||
/**
|
||||
* @class Mma
|
||||
* @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation
|
||||
* (e.g., mfma or wmma), this class performs fragment-wise (MmaTile) decomposition to
|
||||
* matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and
|
||||
* accumulates results into output WaveTile (C: WaveTileM x WaveTileN).
|
||||
* @tparam ADataType Data type of input WaveTile A
|
||||
* @tparam BDataType Data type of input WaveTile B
|
||||
* @tparam CDataType Data type of input/output WaveTile C (accumulator)
|
||||
* @tparam WaveTileM Mma WaveTile M dimension
|
||||
* @tparam WaveTileN Mma WaveTile K dimension
|
||||
* @tparam WaveTileK Mma WaveTile M dimension
|
||||
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
|
||||
* @tparam SwapAB Swaps A and B input vectors
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam MmaOp_ Backend wrapper class that will perform the mma op (e.g., mfma or wmma)
|
||||
* @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles
|
||||
* @par This is an example of an Mma decomposition driver class that can be used in a wave-tile
|
||||
* context. Given a WaveTile size, we can decompose the WaveTile into smaller mma op fragments
|
||||
* that are natively supported by the hardware (e.g., mfma or wmma). The class also supports
|
||||
* applying transforms to the input/output frags as needed (e.g., layout conversions, data type
|
||||
* conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the
|
||||
* output WaveTile. This is a powerful example of how to build a flexible and reusable mma driver
|
||||
* that can adapt to different hardware capabilities and requirements.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileK,
|
||||
MmaOpFamily OpFamily,
|
||||
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
|
||||
bool SwapAB = false,
|
||||
typename CompilerTarget =
|
||||
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
|
||||
// get_compiler_target(),
|
||||
typename MmaOp_ =
|
||||
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
|
||||
// MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
CompilerTarget,
|
||||
OpFamily>::SelectedOp,
|
||||
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
|
||||
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
|
||||
// clang-format off
|
||||
struct WaveWiseMmaPipeline : public MmaPipelineBase<dense::wavewise::detail::getPipelineFlags<SwapAB>(),
|
||||
WaveWiseMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, OpFamily, AccumPolicy, SwapAB, CompilerTarget, MmaOp_, MmaTransforms>>
|
||||
{
|
||||
using Base = MmaPipelineBase<dense::wavewise::detail::getPipelineFlags<SwapAB>(),
|
||||
WaveWiseMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, OpFamily, AccumPolicy, SwapAB, CompilerTarget, MmaOp_, MmaTransforms>>;
|
||||
// clang-format on
|
||||
using MmaOp = MmaOp_;
|
||||
|
||||
// Fragment dimensions
|
||||
constexpr static uint32_t FragM = MmaOp::kM;
|
||||
constexpr static uint32_t FragN = MmaOp::kN;
|
||||
constexpr static uint32_t FragK = MmaOp::kK;
|
||||
|
||||
// Fragment counts for decomposition
|
||||
constexpr static uint32_t FragsM = WaveTileM / FragM;
|
||||
constexpr static uint32_t FragsN = WaveTileN / FragN;
|
||||
constexpr static uint32_t FragsK = WaveTileK / FragK;
|
||||
constexpr static uint32_t FragsC = FragsM * FragsN;
|
||||
|
||||
// Vector types for packed registers in each fragment
|
||||
using InternalAVecT = typename MmaOp::AVecType;
|
||||
using InternalBVecT = typename MmaOp::BVecType;
|
||||
using InternalCVecT = typename MmaOp::CVecType;
|
||||
|
||||
// Buffer types for WaveTiles
|
||||
using AVecType = InternalAVecT[FragsM][FragsK];
|
||||
using BVecType = InternalBVecT[FragsN][FragsK];
|
||||
using CVecType = InternalCVecT[FragsM][FragsN];
|
||||
|
||||
// Transforms
|
||||
using ATransform = typename MmaTransforms::ATransform;
|
||||
using BTransform = typename MmaTransforms::BTransform;
|
||||
using CTransform = typename MmaTransforms::CTransform;
|
||||
using DTransform = typename MmaTransforms::DTransform;
|
||||
|
||||
// Sanity checks
|
||||
static_assert(WaveTileM >= FragM, "WaveTileM must be larger than FragM");
|
||||
static_assert(WaveTileN >= FragN, "WaveTileN must be larger than FragN");
|
||||
static_assert(WaveTileK >= FragK, "WaveTileK must be larger than FragK");
|
||||
static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM");
|
||||
static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN");
|
||||
static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK");
|
||||
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static void execImpl(std::tuple<VecTA, VecTB, VecTC>& vecs)
|
||||
{
|
||||
auto& [a_frag, b_frag, c_frag] = vecs;
|
||||
|
||||
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
|
||||
{
|
||||
// "Row-major" accumulation over the N-dimension fragments first.
|
||||
// Pseudo code here, but we would basically iterate over the fragments in row-major
|
||||
// order. We also have to ensure that the incoming vector WaveTiles are converted to
|
||||
// native vector types before passing to the FragWiseMma exec function.
|
||||
for(uint32_t bm = 0u; bm < FragsM; ++bm)
|
||||
{
|
||||
for(uint32_t bn = 0u; bn < FragsN; ++bn)
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < FragsK; ++bk)
|
||||
{
|
||||
c_frag[bm][bn] =
|
||||
MmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
|
||||
{
|
||||
// "Col-major" accumulation over the M-dimension fragments first.
|
||||
// Pseudo code here, but we would basically iterate over the blocks in col-major order
|
||||
for(uint32_t bn = 0u; bn < FragsN; ++bn)
|
||||
{
|
||||
for(uint32_t bm = 0u; bm < FragsM; ++bm)
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < FragsK; ++bk)
|
||||
{
|
||||
c_frag[bm][bn] =
|
||||
MmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
229
include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp
Normal file
229
include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp
Normal file
@@ -0,0 +1,229 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the Scale MFMA instruction for fp8_t A and B
|
||||
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 2, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
|
||||
static_cast<int>(CtrlFlags::OPSEL_A),
|
||||
scale_A,
|
||||
static_cast<int>(CtrlFlags::OPSEL_B),
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the Scale MFMA instruction for bf8_t A and B
|
||||
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 2, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
|
||||
static_cast<int>(CtrlFlags::OPSEL_A),
|
||||
scale_A,
|
||||
static_cast<int>(CtrlFlags::OPSEL_B),
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the Scale MFMA instruction for pk_fp4_t A and B
|
||||
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
|
||||
static_cast<int>(CtrlFlags::OPSEL_A),
|
||||
scale_A,
|
||||
static_cast<int>(CtrlFlags::OPSEL_B),
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the Scale MFMA instruction for fp8_t A and B
|
||||
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 2, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
|
||||
static_cast<int>(CtrlFlags::OPSEL_A),
|
||||
scale_A,
|
||||
static_cast<int>(CtrlFlags::OPSEL_B),
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the Scale MFMA instruction for bf8_t A and B
|
||||
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 2, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
|
||||
static_cast<int>(CtrlFlags::OPSEL_A),
|
||||
scale_A,
|
||||
static_cast<int>(CtrlFlags::OPSEL_B),
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the Scale MFMA instruction for pk_fp4_t A and B
|
||||
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
|
||||
static_cast<int>(CtrlFlags::OPSEL_A),
|
||||
scale_A,
|
||||
static_cast<int>(CtrlFlags::OPSEL_B),
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
149
include/ck_tile/core/arch/mma/scale/mfma/selector.hpp
Normal file
149
include/ck_tile/core/arch/mma/scale/mfma/selector.hpp
Normal file
@@ -0,0 +1,149 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class ScaleMfmaDefaultSelector
|
||||
* @brief Implements a default scale MFMA selector strategy. The SelectedOp can be unsupported.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM Size of the M dimension
|
||||
* @tparam WaveTileN Size of the N dimension
|
||||
* @tparam WaveTileKTest Size of the K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
std::uint32_t WaveTileM,
|
||||
std::uint32_t WaveTileN,
|
||||
std::uint32_t WaveTileKTest,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) &&
|
||||
// is_power_of_two_integer(WaveTileKTest))
|
||||
struct ScaleMfmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
// Define our candidate MFMA implementation for the current parameters
|
||||
using CandidateOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest,
|
||||
DefaultScaleMfmaCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SCALE>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
|
||||
CandidateOp,
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest,
|
||||
void,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the CDNA default MMA selector strategy for scale MFMA.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
|
||||
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
|
||||
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
std::uint32_t WaveTileM,
|
||||
std::uint32_t WaveTileN,
|
||||
std::uint32_t WaveTileK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
CompilerTarget,
|
||||
OpFamily,
|
||||
enable_if_all<std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID,
|
||||
amdgcn_target_id::GFX950)>,
|
||||
std::enable_if_t<OpFamily == MmaOpFamily::SCALE>>>
|
||||
{
|
||||
private:
|
||||
// Provide the default depth-K search strategy for each class of common MFMA shapes.
|
||||
// Start searching from the largest K dimension MFMA shape down to the smallest.
|
||||
using CandidateOp16x16 = typename ScaleMfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
16u,
|
||||
16u,
|
||||
128u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
using CandidateOp32x32 = typename ScaleMfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
32u,
|
||||
32u,
|
||||
64u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Default operation triggers pass-through
|
||||
using DefaultOp = typename ScaleMfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
1u,
|
||||
1u,
|
||||
1u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the MFMA shape
|
||||
static constexpr bool IsSupported16x16 =
|
||||
MmaOpTraits<CandidateOp16x16>::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) &&
|
||||
(WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u);
|
||||
static constexpr bool IsSupported32x32 =
|
||||
MmaOpTraits<CandidateOp32x32>::IsSupported && (WaveTileM % CandidateOp32x32::kM == 0u) &&
|
||||
(WaveTileN % CandidateOp32x32::kN == 0u) && (WaveTileK % CandidateOp32x32::kK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported MFMA operation for the given fragment shape
|
||||
using SelectedOp =
|
||||
std::conditional_t<IsSupported32x32,
|
||||
CandidateOp32x32,
|
||||
std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
10
include/ck_tile/core/arch/mma/scale/scale.hpp
Normal file
10
include/ck_tile/core/arch/mma/scale/scale.hpp
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
// Include scale MFMA traits and architecture-specific implementations
|
||||
#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"
|
||||
77
include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp
Normal file
77
include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp
Normal file
@@ -0,0 +1,77 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
std::uint32_t FragM,
|
||||
std::uint32_t FragN,
|
||||
std::uint32_t FragK,
|
||||
typename CompilerTarget =
|
||||
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
|
||||
// get_compiler_target(),
|
||||
typename MmaOp_ =
|
||||
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp_ = typename
|
||||
// MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SCALE>::SelectedOp,
|
||||
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
|
||||
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
|
||||
// clang-format off
|
||||
struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>
|
||||
{
|
||||
using Base = MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>;
|
||||
// clang-format on
|
||||
|
||||
using MmaOp = MmaOp_; // Expose the selected MmaOp
|
||||
|
||||
// Expose caller-side vector types
|
||||
using AVecType = typename MmaOp::AVecType;
|
||||
using BVecType = typename MmaOp::BVecType;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
|
||||
// Expose internal vector types
|
||||
using InternalAVecT = typename MmaOp::AVecType;
|
||||
using InternalBVecT = typename MmaOp::BVecType;
|
||||
using InternalCVecT = typename MmaOp::CVecType;
|
||||
|
||||
// Transforms
|
||||
using ATransform = typename MmaTransforms::ATransform;
|
||||
using BTransform = typename MmaTransforms::BTransform;
|
||||
using CTransform = typename MmaTransforms::CTransform;
|
||||
using DTransform = typename MmaTransforms::DTransform;
|
||||
|
||||
template <typename VecTA,
|
||||
typename VecTB,
|
||||
typename VecTC,
|
||||
typename ScaleADataType,
|
||||
typename ScaleBDataType>
|
||||
CK_TILE_DEVICE static void
|
||||
execImpl(std::tuple<VecTA, VecTB, VecTC, ScaleADataType, ScaleBDataType>& vecs)
|
||||
{
|
||||
auto& [a_vec, b_vec, c_vec, scale_A, scale_B] = vecs;
|
||||
c_vec = MmaOp::exec(a_vec, b_vec, c_vec, scale_A, scale_B);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
6
include/ck_tile/core/arch/mma/scale/scale_selector.hpp
Normal file
6
include/ck_tile/core/arch/mma/scale/scale_selector.hpp
Normal file
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/mma/scale/mfma/selector.hpp"
|
||||
93
include/ck_tile/core/arch/mma/scale/scale_traits.hpp
Normal file
93
include/ck_tile/core/arch/mma/scale/scale_traits.hpp
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
// #include "ck_tile/core/numeric/pk_fp6.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
namespace scale::detail {
|
||||
|
||||
template <typename T>
|
||||
struct ScaleDataTypeToFlag;
|
||||
|
||||
template <>
|
||||
struct ScaleDataTypeToFlag<fp8_t> // e4m3
|
||||
{
|
||||
static constexpr std::int32_t value = 0;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ScaleDataTypeToFlag<bf8_t> // e5m2
|
||||
{
|
||||
static constexpr std::int32_t value = 1;
|
||||
};
|
||||
|
||||
// template <>
|
||||
// struct ScaleDataTypeToFlag<pk_fp6_t<1>> // e2m3
|
||||
// {
|
||||
// static constexpr std::int32_t value = 2;
|
||||
// };
|
||||
|
||||
// template <>
|
||||
// struct ScaleDataTypeToFlag<bf6_t> // e3m2
|
||||
// {
|
||||
// static constexpr std::int32_t value = 3;
|
||||
// };
|
||||
|
||||
template <>
|
||||
struct ScaleDataTypeToFlag<pk_fp4_t> // e2m1
|
||||
{
|
||||
static constexpr std::int32_t value = 4;
|
||||
};
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @concept ScaleMfmaDataTypeToFlag
|
||||
* @brief Expresses the interface of required members for each DataTypeToFlag type on Gfx9
|
||||
*/
|
||||
template <typename DataTypeToFlag>
|
||||
concept ScaleMfmaDataTypeToFlag = requires(DataTypeToFlag dataTypeToFlag) {
|
||||
// Flag members for scale MFMA instructions
|
||||
{ DataTypeToFlag::value } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
template <typename T>
|
||||
inline constexpr std::int32_t ScaleDataTypeToFlag_v = ScaleDataTypeToFlag<T>::value;
|
||||
|
||||
} // namespace scale::detail
|
||||
|
||||
struct DefaultScaleMfmaCtrlFlags
|
||||
{
|
||||
static constexpr std::int32_t OPSEL_A = 0;
|
||||
static constexpr std::int32_t OPSEL_B = 0;
|
||||
};
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @concept ScaleMfmaCtrlFlags
|
||||
* @brief Expresses the interface of required members for each CtrlFlags type on Gfx9
|
||||
*/
|
||||
template <typename CtrlFlags>
|
||||
concept ScaleMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) {
|
||||
// Flag members for scale MFMA instructions
|
||||
{ CtrlFlags::OPSEL_A } -> std::convertible_to<int>;
|
||||
{ CtrlFlags::OPSEL_B } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
43
include/ck_tile/core/arch/mma/scale/scale_transforms.hpp
Normal file
43
include/ck_tile/core/arch/mma/scale/scale_transforms.hpp
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultTransformsScale
|
||||
* @brief Implements the default MMA transforms for Scale
|
||||
*/
|
||||
struct MmaDefaultTransformsScale
|
||||
{
|
||||
using ATransform = PassThroughTransform;
|
||||
using BTransform = PassThroughTransform;
|
||||
using CTransform = PassThroughTransform;
|
||||
using DTransform = PassThroughTransform;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaTransformsDefaultSelector
|
||||
* @brief Specialization for Scale MFMA transforms
|
||||
* Provides default transform selection for scale operations
|
||||
*
|
||||
* @tparam MmaOp Scale MMA operation
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_mma_op_scale(MmaOp))
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::SCALE>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsScale;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
@@ -6,7 +6,6 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
@@ -31,25 +30,12 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
|
||||
{
|
||||
static constexpr index_t ABVecN = vector_traits<AVecType>::vector_size;
|
||||
static constexpr index_t kCompressionRatio = 2;
|
||||
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
|
||||
|
||||
static_assert(CompressedSize == 4);
|
||||
// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
|
||||
// and evaluate changing this to a transform at a higher level.
|
||||
// aVec not being const can cause problems when running multiple intrinsics.
|
||||
const uint32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
|
||||
|
||||
const AVecCompressed a_vec_pruned = {aVec[0], aVec[1], aVec[2], aVec[3]};
|
||||
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x32_f16(
|
||||
a_vec_pruned, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
100
include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp
Normal file
100
include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
namespace sparse::detail {
|
||||
// TODO: c++20: return MmaPipelineOptionFlags directly
|
||||
constexpr inline int getPipelineFlags()
|
||||
{
|
||||
return static_cast<int>(MmaPipelineOptionFlag::COMPRESS_A);
|
||||
}
|
||||
} // namespace sparse::detail
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget =
|
||||
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
|
||||
// get_compiler_target(),
|
||||
typename MmaOp_ =
|
||||
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
|
||||
// MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE>::SelectedOp,
|
||||
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
|
||||
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
|
||||
// clang-format off
|
||||
struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>
|
||||
{
|
||||
using Base = MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>;
|
||||
// clang-format on
|
||||
|
||||
static_assert(!Base::template hasFlag<MmaPipelineOptionFlag::ABSwap>(),
|
||||
"Cannot transpose C in sparse intrinsics.");
|
||||
|
||||
using MmaOp = MmaOp_; // Expose the selected MmaOp
|
||||
|
||||
// Calculate the uncompressed A vector type
|
||||
struct ExternalAVecCalculator
|
||||
{
|
||||
using AVecTraits = vector_traits<typename MmaOp::AVecType>;
|
||||
static constexpr index_t ASize = AVecTraits::vector_size * MmaOp::kCompressionRatio;
|
||||
using AVecType = ext_vector_t<typename AVecTraits::scalar_type, ASize>;
|
||||
};
|
||||
|
||||
// Expose caller-side vector types
|
||||
using AVecType = typename ExternalAVecCalculator::AVecType;
|
||||
using BVecType = typename MmaOp::BVecType;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
|
||||
// Expose internal vector types
|
||||
using InternalAVecT = typename MmaOp::AVecType;
|
||||
using InternalBVecT = typename MmaOp::BVecType;
|
||||
using InternalCVecT = typename MmaOp::CVecType;
|
||||
|
||||
// Transforms
|
||||
using ATransform = typename MmaTransforms::ATransform;
|
||||
using BTransform = typename MmaTransforms::BTransform;
|
||||
using CTransform = typename MmaTransforms::CTransform;
|
||||
using DTransform = typename MmaTransforms::DTransform;
|
||||
|
||||
template <typename ATransformResult, typename BTransformResult, typename CTransformResult>
|
||||
CK_TILE_DEVICE static void
|
||||
execImpl(std::tuple<ATransformResult, BTransformResult, CTransformResult>& vecs)
|
||||
{
|
||||
checkATransformResult<ATransformResult>();
|
||||
auto& [a_result, b_vec, c_vec] = vecs;
|
||||
auto& [a_vec, idx] = a_result;
|
||||
c_vec = MmaOp::exec(a_vec, b_vec, c_vec, idx);
|
||||
}
|
||||
|
||||
private:
|
||||
// Type check helper - not a device function, so std::declval is available
|
||||
template <typename ATransformResult>
|
||||
static constexpr void checkATransformResult()
|
||||
{
|
||||
using ExternalAvecRef = std::add_lvalue_reference_t<AVecType>;
|
||||
static_assert(std::is_same_v<ATransformResult,
|
||||
decltype(ATransform::exec(std::declval<ExternalAvecRef>()))>);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
@@ -43,18 +43,15 @@ struct BuiltinParams
|
||||
template <SparseCompressionIndex Idx>
|
||||
static constexpr BuiltinParams getBuiltinParams()
|
||||
{
|
||||
BuiltinParams params;
|
||||
// TODO c++20: designated initializers
|
||||
if constexpr(Idx == SparseCompressionIndex::FIRST)
|
||||
{
|
||||
params.UseFirstIndex = 1;
|
||||
params.ByteIndexToOverride = 0;
|
||||
return BuiltinParams{1, 0};
|
||||
}
|
||||
else
|
||||
{
|
||||
params.UseFirstIndex = 0;
|
||||
params.ByteIndexToOverride = static_cast<int>(Idx);
|
||||
return BuiltinParams{0, static_cast<int>(Idx)};
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
} // namespace sparse::detail
|
||||
|
||||
@@ -6,22 +6,101 @@
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include <cstdint>
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
namespace sparse::detail {
|
||||
/**
|
||||
* @struct MmaDefaultTransformsSparse
|
||||
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
|
||||
* elements into lower part of a_vec to half its effective size.
|
||||
* @param a_vec Vector to be compressed.
|
||||
* @tparam ADataType The data type of a_vec
|
||||
* @tparam CompressedSize The target compression size
|
||||
* @tparam AVec The vector type of a_vec (deduced)
|
||||
* @return Packed 32‑bit word containing **CompressedSize** 2‑bit fields.
|
||||
* Each field encodes the original position (0–3) of the corresponding
|
||||
* non‑zero element in the input. If fewer than CompressedSize
|
||||
* non‑zeros are found, remaining fields default to 2 (see below).
|
||||
*/
|
||||
template <typename ADataType, index_t CompressedSize, typename AVec>
|
||||
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
|
||||
{
|
||||
// idx holds one 2‑bit index per output element (total CompressedSize entries).
|
||||
// It is initialized to the pattern 0b10 for every field. This matches
|
||||
// what the hardware expects when there are fewer than two non‑zero values
|
||||
// in a 4‑element group – the unused output is treated as coming from slot 2.
|
||||
// The loop below will clear and set each field as real non‑zeros are seen.
|
||||
int32_t idx = 0;
|
||||
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2u << (2u * k)); });
|
||||
|
||||
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
|
||||
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
|
||||
int32_t non_zero_pos = 0;
|
||||
|
||||
static_for<0, 4, 1>{}([&](auto j) {
|
||||
if(static_cast<float>(a_vec[i * 4 + j]) != 0.0f)
|
||||
{
|
||||
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
|
||||
// clear the two‑bit field for this output and insert j
|
||||
idx &= ~(0b11u << (2u * (i * 2 + non_zero_pos)));
|
||||
idx |= static_cast<uint32_t>(j) << (2u * (i * 2 + non_zero_pos));
|
||||
++non_zero_pos;
|
||||
}
|
||||
});
|
||||
a_vec[i * 2] = nonzero_elems[0];
|
||||
a_vec[i * 2 + 1] = nonzero_elems[1];
|
||||
});
|
||||
|
||||
return idx;
|
||||
}
|
||||
} // namespace sparse::detail
|
||||
|
||||
/**
|
||||
* @class SparseCompressTransform
|
||||
* @brief Performs 2:4 structured sparsity compression to the vector v and produces an index mask.
|
||||
* @note Returns a tuple of two. The first element is the vector v with the same scalar type but
|
||||
* its size halved. The second element is the index mask.
|
||||
*/
|
||||
template <index_t CompressionRatio>
|
||||
struct SparseCompressTransform
|
||||
{
|
||||
template <typename VecType>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecType& v)
|
||||
{
|
||||
using VecTraits = vector_traits<remove_cvref_t<VecType>>;
|
||||
using ScalarT = typename VecTraits::scalar_type;
|
||||
static constexpr auto VecN = VecTraits::vector_size;
|
||||
static constexpr index_t CompressedSize = VecN / CompressionRatio;
|
||||
using VecCompressed = ext_vector_t<ScalarT, CompressedSize>;
|
||||
|
||||
static_assert(VecN % CompressionRatio == 0, "VecN must be divisible by CompressionRatio");
|
||||
static_assert(CompressedSize > 0, "CompressedSize must be > 0");
|
||||
|
||||
const auto idx = sparse::detail::compress_a_impl<ScalarT, CompressedSize>(v);
|
||||
|
||||
// TODO c++20: Use bit_cast
|
||||
return std::tuple<VecCompressed&, int32_t>(
|
||||
*std::launder(reinterpret_cast<VecCompressed*>(&v)), idx);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @class MmaDefaultTransformsSparse
|
||||
* @brief Implements the default transforms for Sparse
|
||||
*
|
||||
* For 2:4 structured sparsity with inline register metadata:
|
||||
* - ATransform: Pass-through (sparse operands formatted in Exec) TODO!
|
||||
* - ATransform: 2:4 structured sparsity compression
|
||||
* - BTransform: Pass-through (sparse operands already formatted)
|
||||
* - CTransform: Pass-through (input accumulator)
|
||||
* - DTransform: Pass-through (output accumulator as-is)
|
||||
*/
|
||||
template <index_t CompressionRatio>
|
||||
struct MmaDefaultTransformsSparse
|
||||
{
|
||||
using ATransform = PassThroughTransform;
|
||||
using ATransform = SparseCompressTransform<CompressionRatio>;
|
||||
using BTransform = PassThroughTransform;
|
||||
using CTransform = PassThroughTransform;
|
||||
using DTransform = PassThroughTransform;
|
||||
@@ -42,7 +121,7 @@ struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::SPARSE>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsSparse;
|
||||
using SelectedTransforms = MmaDefaultTransformsSparse<MmaOp::kCompressionRatio>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
@@ -21,23 +20,9 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
|
||||
{
|
||||
static constexpr index_t ABVecN = vector_traits<AVecType>::vector_size;
|
||||
static constexpr index_t kCompressionRatio = 2;
|
||||
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
|
||||
|
||||
static_assert(CompressedSize == 8);
|
||||
// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
|
||||
// and evaluate changing this to a transform at a higher level.
|
||||
// aVec not being const can cause problems when running multiple intrinsics.
|
||||
const uint32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
|
||||
|
||||
const AVecCompressed a_vec_pruned = {
|
||||
aVec[0], aVec[1], aVec[2], aVec[3], aVec[4], aVec[5], aVec[6], aVec[7]};
|
||||
|
||||
return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(a_vec_pruned, bVec, cVec, idx)};
|
||||
return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(aVec, bVec, cVec, idx)};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
/**
|
||||
* @class TileDistrEncCalc
|
||||
* @brief Given an MmaOp and modifiers, provides warp-level tile distribution encodings for mapping
|
||||
* ABC matrix fragment coordinates to register coordinates (lane, vector item) and vice versa.
|
||||
* @tparam MmaOp Intrinsic (amdgcn_mma).
|
||||
* @tparam CTranspose Whether we are using CTranspose.
|
||||
* @tparam SFactor Swizzle factor. Not implemented.
|
||||
* @tparam AttrNumAccessA Requested NumAccess for the A matrix. Must be multiple of "fundamental"
|
||||
* NumAccess for intrinsic. See details in amdgcn_mma.hpp.
|
||||
* @tparam AttrNumAccessB Requested NumAccess for the B matrix.
|
||||
*/
|
||||
template <typename MmaOp,
|
||||
bool CTranspose = false,
|
||||
index_t SFactor = 1,
|
||||
index_t AttrNumAccessA = MmaOp::kAKNumAccess,
|
||||
index_t AttrNumAccessB = MmaOp::kBKNumAccess>
|
||||
struct TileDistrEncCalc
|
||||
{
|
||||
private:
|
||||
static constexpr index_t NumAccessA = std::max(MmaOp::kAKNumAccess, AttrNumAccessA);
|
||||
static constexpr index_t NumAccessB = std::max(MmaOp::kBKNumAccess, AttrNumAccessB);
|
||||
|
||||
// We are free to choose any NumAccess value to manipulate the load / store behavior, unless the
|
||||
// intrinsic fundamentally requires a base NumAccess factor for the layout to be correct.
|
||||
static_assert(AttrNumAccessA % MmaOp::kAKNumAccess == 0,
|
||||
"Requesting NumAccessA incompatible with builtin.");
|
||||
static_assert(AttrNumAccessB % MmaOp::kBKNumAccess == 0,
|
||||
"Requesting NumAccessB incompatible with builtin.");
|
||||
|
||||
static_assert(MmaOp::kABKPerLane % NumAccessA == 0);
|
||||
static_assert(MmaOp::kABKPerLane % NumAccessB == 0);
|
||||
static_assert(SFactor == 1, "Swizzle not implemented yet."); // TODO: Implement Swizzle.
|
||||
|
||||
template <index_t MajorDimSize, index_t Repeat, index_t NumAccess, index_t CompressionRatio = 1>
|
||||
using ABWarpDstrEnc = tile_distribution_encoding<
|
||||
sequence<Repeat>,
|
||||
tuple<sequence<MajorDimSize>,
|
||||
sequence<NumAccess,
|
||||
MmaOp::kK / MmaOp::kABKPerLane,
|
||||
MmaOp::kABKPerLane / NumAccess / CompressionRatio>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
static constexpr auto get_cwarp_dstr_encoding()
|
||||
{
|
||||
// We unmerge the M and N dimensions in the same way every time.
|
||||
using MSubDims = sequence<MmaOp::kCMBlocks,
|
||||
MmaOp::kCMNumAccess,
|
||||
MmaOp::kM / MmaOp::kCMBlocks / MmaOp::kCMPerLane,
|
||||
MmaOp::kCMPerLane / MmaOp::kCMNumAccess>;
|
||||
using NSubDims = sequence<MmaOp::kCNBlocks, MmaOp::kN / MmaOp::kCNBlocks>;
|
||||
|
||||
// In case of CTranspose, all we do is swap the M and N dimension.
|
||||
using MatDims =
|
||||
std::conditional_t<CTranspose, tuple<NSubDims, MSubDims>, tuple<MSubDims, NSubDims>>;
|
||||
constexpr int MInx = CTranspose ? 2 : 1;
|
||||
constexpr int NInx = CTranspose ? 1 : 2;
|
||||
|
||||
// For MFMA intrinsics with blocks, the block dimensions might be in the Lane dim or in the
|
||||
// Vec dim, so we get different merge orderings.
|
||||
if constexpr(MmaOp::CBlockDimInVecDim)
|
||||
{
|
||||
return tile_distribution_encoding<sequence<1>,
|
||||
MatDims,
|
||||
tuple<sequence<MInx, NInx>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
sequence<MInx, NInx, MInx, MInx>,
|
||||
sequence<0, 0, 1, 3>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_distribution_encoding<sequence<1>,
|
||||
MatDims,
|
||||
tuple<sequence<MInx, NInx, MInx, NInx>>,
|
||||
tuple<sequence<0, 0, 2, 1>>,
|
||||
sequence<MInx, MInx>,
|
||||
sequence<1, 3>>{};
|
||||
}
|
||||
}
|
||||
|
||||
using AEnc_ = ABWarpDstrEnc<MmaOp::kM, MmaOp::kARepeat, NumAccessA, MmaOp::kCompressionRatio>;
|
||||
using BEnc_ = ABWarpDstrEnc<MmaOp::kN, MmaOp::kBRepeat, NumAccessB>;
|
||||
|
||||
public:
|
||||
// When using CTranspose, the A and B matrices are swapped.
|
||||
using AWarpDstrEncoding = std::conditional_t<CTranspose, BEnc_, AEnc_>;
|
||||
using BWarpDstrEncoding = std::conditional_t<CTranspose, AEnc_, BEnc_>;
|
||||
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
|
||||
|
||||
// Some additional consistency checks
|
||||
static_assert(TileDistrEncRegMap<AWarpDstrEncoding>::num_lanes == MmaOp::WaveSize);
|
||||
static_assert(TileDistrEncRegMap<BWarpDstrEncoding>::num_lanes == MmaOp::WaveSize);
|
||||
static_assert(TileDistrEncRegMap<CWarpDstrEncoding>::num_lanes == MmaOp::WaveSize);
|
||||
|
||||
static_assert(TileDistrEncRegMap<AWarpDstrEncoding>::num_vector_items ==
|
||||
vector_traits<typename MmaOp::AVecType>::vector_size);
|
||||
static_assert(TileDistrEncRegMap<BWarpDstrEncoding>::num_vector_items ==
|
||||
vector_traits<typename MmaOp::BVecType>::vector_size);
|
||||
static_assert(TileDistrEncRegMap<CWarpDstrEncoding>::num_vector_items ==
|
||||
vector_traits<typename MmaOp::CVecType>::vector_size);
|
||||
};
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
@@ -86,9 +86,11 @@ struct MmaDefaultTransformsGfx12
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
|
||||
// TODO: c++20 requires
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx11_t<CompilerTarget>>
|
||||
struct MmaTransformsDefaultSelector<
|
||||
MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_all<enable_if_target_family_gfx11_t<CompilerTarget>,
|
||||
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::DENSE>>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsGfx11;
|
||||
};
|
||||
@@ -102,9 +104,11 @@ struct MmaTransformsDefaultSelector<MmaOp,
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
|
||||
// TODO: c++20 requires
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct MmaTransformsDefaultSelector<
|
||||
MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_all<enable_if_target_family_gfx12_t<CompilerTarget>,
|
||||
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::DENSE>>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsGfx12;
|
||||
};
|
||||
|
||||
@@ -74,6 +74,7 @@
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_CNAN 5
|
||||
|
||||
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_STANDARD
|
||||
@@ -209,6 +210,17 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// workaround for AMDGPU compiler VGPR aliasing bug in dropout codegen (ROCm >= 7.12)
|
||||
// Philox RNG VGPR parameters get aliased under high register pressure (d256 tile).
|
||||
// fp16 is affected; bf16 is not (different type conversion codegen path).
|
||||
#ifndef CK_TILE_WORKAROUND_ROCM_7_12_FP16_DROPOUT_MISCOMPILE
|
||||
#if(HIP_VERSION_MAJOR == 7 && HIP_VERSION_MINOR >= 12) || (HIP_VERSION_MAJOR > 7)
|
||||
#define CK_TILE_WORKAROUND_ROCM_7_12_FP16_DROPOUT_MISCOMPILE 1
|
||||
#else
|
||||
#define CK_TILE_WORKAROUND_ROCM_7_12_FP16_DROPOUT_MISCOMPILE 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_DEBUG_LOG
|
||||
#define CK_TILE_DEBUG_LOG 0
|
||||
#endif
|
||||
|
||||
@@ -84,19 +84,6 @@ struct array
|
||||
data[i] = static_cast<value_type>(c);
|
||||
}
|
||||
|
||||
// template <typename Y>
|
||||
// CK_TILE_HOST_DEVICE constexpr array(const array& o)
|
||||
// {
|
||||
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
|
||||
// __content = o.__content;
|
||||
// }
|
||||
// CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o)
|
||||
// {
|
||||
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
|
||||
// __content = o.__content;
|
||||
// return *this;
|
||||
// }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
|
||||
|
||||
@@ -247,13 +234,6 @@ CK_TILE_HOST_DEVICE constexpr details::return_type<D, Ts...> make_array(Ts&&...
|
||||
return {std::forward<Ts>(ts)...};
|
||||
}
|
||||
|
||||
// // make empty array
|
||||
// template <typename T>
|
||||
// CK_TILE_HOST_DEVICE constexpr auto make_array()
|
||||
// {
|
||||
// return array<T, 0>{};
|
||||
// }
|
||||
|
||||
// compatible with old ck's initializer, make an array and fill it withe the last element from
|
||||
// initializer_list
|
||||
template <typename T, index_t Size>
|
||||
|
||||
@@ -480,32 +480,6 @@ struct sequence_split
|
||||
using right_type = decltype(Seq::extract(range1{}));
|
||||
};
|
||||
|
||||
#if 0
|
||||
// reverse sequence
|
||||
template <typename Seq>
|
||||
struct sequence_reverse
|
||||
{
|
||||
static constexpr index_t NSize = Seq{}.size();
|
||||
|
||||
using seq_split = sequence_split<Seq, NSize / 2>;
|
||||
using type = typename sequence_merge<
|
||||
typename sequence_reverse<typename seq_split::right_type>::type,
|
||||
typename sequence_reverse<typename seq_split::left_type>::type>::type;
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
struct sequence_reverse<sequence<I>>
|
||||
{
|
||||
using type = sequence<I>;
|
||||
};
|
||||
|
||||
template <index_t I0, index_t I1>
|
||||
struct sequence_reverse<sequence<I0, I1>>
|
||||
{
|
||||
using type = sequence<I1, I0>;
|
||||
};
|
||||
#endif
|
||||
|
||||
namespace detail {
|
||||
template <typename Id, index_t... Ns>
|
||||
struct seq_reverse;
|
||||
|
||||
@@ -24,18 +24,4 @@ using statically_indexed_array = array<T, N>;
|
||||
#endif
|
||||
|
||||
// consider always use ck_tile::array for this purpose
|
||||
#if 0
|
||||
template <typename X, typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
|
||||
{
|
||||
return statically_indexed_array<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
|
||||
}
|
||||
|
||||
// make empty statically_indexed_array
|
||||
template <typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array()
|
||||
{
|
||||
return statically_indexed_array<X, 0>();
|
||||
}
|
||||
#endif
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -23,18 +23,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
|
||||
}
|
||||
#else
|
||||
|
||||
#if 0
|
||||
template <typename T, index_t N>
|
||||
using thread_buffer = array<T, N>;
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
|
||||
{
|
||||
return make_array(ts...);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
template<typename T_, index_t N_>
|
||||
struct thread_buffer {
|
||||
@@ -103,25 +91,6 @@ struct thread_buffer {
|
||||
return vx.data;
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <typename X_,
|
||||
index_t Is,
|
||||
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void _set_as(number<Is> is, X_ x)
|
||||
{
|
||||
using X = remove_cvref_t<X_>;
|
||||
|
||||
constexpr index_t kSPerX = vector_traits<X>::vector_size;
|
||||
|
||||
union {
|
||||
X_ data;
|
||||
tuple_array<value_type, kSPerX> sub_data;
|
||||
} vx {x};
|
||||
|
||||
static_for<0, kSPerX, 1>{}(
|
||||
[&](auto j) { operator()((is * number<sizeof(X_)/sizeof(value_type)>{}) + j) = vx.sub_data[j]; });
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#define TB_COMMON_AS() \
|
||||
|
||||
@@ -292,9 +292,6 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
// below function should be used under tuple_array<> type, no extra check will perform here
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast<tuple_array<Tx, size()>&>(*this); }
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { return reinterpret_cast<const tuple_array<Tx, size()>&>(*this); }
|
||||
// below index is for index *AFTER* type convert, not before
|
||||
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i); }
|
||||
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(i); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
|
||||
|
||||
@@ -333,13 +330,6 @@ struct vector_traits<tuple<T...>, void>
|
||||
static constexpr index_t vector_size = sizeof...(T);
|
||||
};
|
||||
|
||||
// template <class... T>
|
||||
// CK_TILE_HOST_DEVICE constexpr
|
||||
// tuple<T...>
|
||||
// make_tuple(T const&... t)
|
||||
// {
|
||||
// return {t...};
|
||||
// }
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple<Xs...>& a, const tuple<Xs...>& b)
|
||||
{
|
||||
|
||||
@@ -22,7 +22,8 @@ enum class bf16_rounding_mode
|
||||
truncate_with_nan,
|
||||
truncate,
|
||||
standard_asm,
|
||||
rta_asm, // round to nearest away
|
||||
rta_asm, // round to nearest away
|
||||
standard_cnan, // rtn with canonical NaN
|
||||
};
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
@@ -226,6 +227,39 @@ uint16_t float_to_bf16_rta_asm(float f)
|
||||
return u.hi;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr bool float_is_nan_raw(float f)
|
||||
{
|
||||
#if defined(__has_builtin) && __has_builtin(__builtin_isnan)
|
||||
return __builtin_isnan(f);
|
||||
#else
|
||||
uint32_t bits = bit_cast<uint32_t>(f);
|
||||
constexpr uint32_t exp_mask = 0x7f800000;
|
||||
constexpr uint32_t mant_mask = 0x007fffff;
|
||||
|
||||
return (bits & exp_mask) == exp_mask && (bits & mant_mask);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Round to nearest even, but canonicalize any NaN input to the canonical quiet bf16 NaN
|
||||
// (`0x7fff`). Unlike `float_to_bf16_rtn_raw`, this does not preserve signaling NaN
|
||||
// payload/state.
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr uint16_t float_to_bf16_rtn_cnan_raw(float f)
|
||||
{
|
||||
#if defined(__FAST_MATH__) || (defined(__FINITE_MATH_ONLY__) && __FINITE_MATH_ONLY__)
|
||||
// Fast/finite-math can fold the NaN predicate away, so fall back to standard RTN.
|
||||
return float_to_bf16_rtn_raw(f);
|
||||
#else
|
||||
// `-fgpu-flush-denormals-to-zero` only affects denormals, not NaN handling.
|
||||
uint32_t bits = bit_cast<uint32_t>(f);
|
||||
uint32_t tmp = (bits >> 16) & 1;
|
||||
uint32_t res = float_is_nan_raw(f) ? 0x7fff0000 : bits + tmp + 0x7fff;
|
||||
|
||||
return uint16_t(res >> 16);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Truncate instead of rounding, preserving SNaN
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
|
||||
@@ -249,6 +283,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
|
||||
return float_to_bf16_rtn_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::standard_asm)
|
||||
return float_to_bf16_rtn_asm(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::standard_cnan)
|
||||
return float_to_bf16_rtn_cnan_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
|
||||
return float_to_bf16_truc_nan_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::rta_asm)
|
||||
|
||||
@@ -264,93 +264,6 @@ bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.t
|
||||
CK_TILE_DEVICE
|
||||
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
#if 0
|
||||
CK_TILE_DEVICE
|
||||
half_t operator+(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator-(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator*(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator/(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator+=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator-=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator*=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator/=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator++(half_t& x)
|
||||
{
|
||||
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator--(half_t& x)
|
||||
{
|
||||
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator++(half_t& x, int)
|
||||
{
|
||||
half_t y(x);
|
||||
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator--(half_t& x, int)
|
||||
{
|
||||
half_t y(x);
|
||||
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return y;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
|
||||
#endif
|
||||
|
||||
@@ -73,27 +73,6 @@ struct numeric<int8_t>
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return 0; }
|
||||
};
|
||||
|
||||
#if 0
|
||||
|
||||
template <>
|
||||
struct numeric_traits<int8_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 10;
|
||||
static constexpr int bias = 15;
|
||||
static constexpr uint16_t nan_mask = 0x7C00;
|
||||
static constexpr uint16_t head_mask = 0xFC00;
|
||||
static constexpr uint16_t mant_mask = 0x3FF;
|
||||
static constexpr uint16_t exp_mask = 0x1F;
|
||||
static constexpr uint32_t Inf = 0x7C00;
|
||||
static constexpr uint32_t NegInf = 0xFC00;
|
||||
static constexpr uint32_t NaN = 0x7C01;
|
||||
static constexpr uint32_t Neg0 = 0x8000;
|
||||
static constexpr int PackedSize = 1;
|
||||
using bitwise_type = uint16_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float int8_to_float(const int8_t& x) { return static_cast<float>(x); }
|
||||
|
||||
|
||||
@@ -295,10 +295,6 @@ struct tile_sweeper
|
||||
F f;
|
||||
};
|
||||
|
||||
// partial deduction is not allowed
|
||||
// template <typename T, typename F, typename U>
|
||||
// tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
|
||||
|
||||
// deduction guide
|
||||
template <typename T,
|
||||
typename F,
|
||||
|
||||
@@ -236,12 +236,13 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
||||
namespace detail {
|
||||
|
||||
template <typename Lengths, typename Strides, index_t I, typename AccOld>
|
||||
CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
|
||||
const Strides& strides,
|
||||
number<I> i,
|
||||
AccOld acc_old)
|
||||
CK_TILE_HOST_DEVICE constexpr long_index_t calculate_element_space_size_impl(const Lengths& lengths,
|
||||
const Strides& strides,
|
||||
number<I> i,
|
||||
AccOld acc_old)
|
||||
{
|
||||
auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i];
|
||||
long_index_t acc_new = acc_old + static_cast<long_index_t>(lengths[i] - number<1>{}) *
|
||||
static_cast<long_index_t>(strides[i]);
|
||||
|
||||
if constexpr(i.value < Lengths::size() - 1)
|
||||
{
|
||||
@@ -287,8 +288,12 @@ make_naive_tensor_descriptor(const tuple<Lengths...>& lengths,
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
|
||||
|
||||
const auto element_space_size =
|
||||
const long_index_t element_space_size_long =
|
||||
detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{});
|
||||
constexpr long_index_t element_space_size_clamp_value =
|
||||
static_cast<long_index_t>(std::numeric_limits<index_t>::max());
|
||||
const index_t element_space_size =
|
||||
static_cast<index_t>(std::min(element_space_size_long, element_space_size_clamp_value));
|
||||
|
||||
using GuaranteedVectorLengths =
|
||||
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
|
||||
@@ -323,8 +328,12 @@ make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
|
||||
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
|
||||
{
|
||||
const auto desc_0 = [&]() {
|
||||
const auto element_space_size = detail::calculate_element_space_size_impl(
|
||||
const auto element_space_size_long = detail::calculate_element_space_size_impl(
|
||||
lengths, strides, number<0>{}, long_number<1>{});
|
||||
constexpr long_index_t element_space_size_clamp_value =
|
||||
static_cast<long_index_t>(std::numeric_limits<index_t>::max());
|
||||
const index_t element_space_size =
|
||||
static_cast<index_t>(std::min(element_space_size_long, element_space_size_clamp_value));
|
||||
|
||||
const auto transforms = make_tuple(make_offset_transform(element_space_size, os));
|
||||
|
||||
|
||||
@@ -454,45 +454,6 @@ struct tile_distribution_detail
|
||||
|
||||
} // namespace detail
|
||||
|
||||
#if 0
|
||||
// this returns a constexpr tile_distribution
|
||||
template <typename StaticTileDistributionEncoding_>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
|
||||
{
|
||||
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
|
||||
|
||||
constexpr auto adaptor_impl =
|
||||
detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
|
||||
|
||||
constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
|
||||
constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
|
||||
constexpr index_t d_length = adaptor_impl.template at<2>();
|
||||
constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
|
||||
|
||||
constexpr auto ps_ys_to_xs_adaptor =
|
||||
CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
|
||||
|
||||
constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
|
||||
|
||||
constexpr auto ys_to_d_descriptor =
|
||||
make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length);
|
||||
|
||||
//
|
||||
constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
|
||||
constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
|
||||
|
||||
constexpr auto rh_major_minor_to_hidden_ids =
|
||||
TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
|
||||
|
||||
return tile_distribution<
|
||||
remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
|
||||
remove_cvref_t<decltype(ys_to_d_descriptor)>,
|
||||
remove_cvref_t<DstrEncode>,
|
||||
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
|
||||
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
|
||||
}
|
||||
#endif
|
||||
|
||||
// this returns a static tile_distribution
|
||||
template <typename StaticTileDistributionEncoding_>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
|
||||
|
||||
@@ -209,4 +209,21 @@ template <typename ADataType, typename BDataType>
|
||||
using largest_type_t =
|
||||
std::conditional_t<sizeof(ADataType) >= sizeof(BDataType), ADataType, BDataType>;
|
||||
|
||||
/**
|
||||
* @brief Type trait to detect whether a type is a @c std::tuple specialization.
|
||||
* @tparam T The type to inspect.
|
||||
*/
|
||||
template <typename T>
|
||||
struct is_std_tuple : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename... Args>
|
||||
struct is_std_tuple<std::tuple<Args...>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static constexpr bool is_std_tuple_v = is_std_tuple<T>::value;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -25,6 +25,7 @@ template <> struct DataTypeTraits<pk_fp4_t> { static constexpr const char * name
|
||||
template <> struct DataTypeTraits<pk_fp6x16_t> { static constexpr const char * name = "pk_fp6x16"; };
|
||||
template <> struct DataTypeTraits<pk_fp4_raw_t> { static constexpr const char * name = "pk_fp4_raw"; };
|
||||
template <> struct DataTypeTraits<e8m0_t> { static constexpr const char * name = "e8m0"; };
|
||||
template <> struct DataTypeTraits<ck_tile::tf32_t>{ static constexpr const char* name = "tf32"; };
|
||||
|
||||
template <memory_operation_enum MemOp> struct memOpToStr;
|
||||
template <> struct memOpToStr<memory_operation_enum::set> { static constexpr const char * name = "set"; };
|
||||
|
||||
@@ -27,10 +27,13 @@ struct ElementWiseKernel
|
||||
return is_wave32() ? kBlockSize / 2 : kBlockSize;
|
||||
}
|
||||
|
||||
template <typename... XDataType, typename Dims>
|
||||
CK_TILE_DEVICE void operator()(const Dims lens,
|
||||
const Dims input_strides,
|
||||
const Dims output_strides,
|
||||
template <typename... XDataType,
|
||||
typename DimsLens,
|
||||
typename DimsInStrides,
|
||||
typename DimsOutStrides>
|
||||
CK_TILE_DEVICE void operator()(const DimsLens lens,
|
||||
const DimsInStrides input_strides,
|
||||
const DimsOutStrides output_strides,
|
||||
const tuple<XDataType...>& input_tensors,
|
||||
YDataType* p_y) const
|
||||
{
|
||||
@@ -49,10 +52,11 @@ struct ElementWiseKernel
|
||||
input_tensors.get(i), lens, input_strides, number<S::kVectorM>{}, number<1>{});
|
||||
|
||||
const auto transformed_tensor = pad_tensor_view(
|
||||
transform_tensor_view(tensor_view,
|
||||
ck_tile::make_tuple(merge_transform),
|
||||
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
|
||||
ck_tile::make_tuple(sequence<0>{})),
|
||||
transform_tensor_view(
|
||||
tensor_view,
|
||||
ck_tile::make_tuple(merge_transform),
|
||||
ck_tile::make_tuple(make_index_sequence<DimsLens::size()>{}),
|
||||
ck_tile::make_tuple(sequence<0>{})),
|
||||
ck_tile::make_tuple(number<S::kBlockM>{}),
|
||||
sequence<Problem::kPad>{});
|
||||
|
||||
@@ -86,13 +90,14 @@ struct ElementWiseKernel
|
||||
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_y, lens, output_strides, number<S::kVectorM>{});
|
||||
|
||||
const auto transformed_y_m_n = pad_tensor_view(
|
||||
transform_tensor_view(y_m_n,
|
||||
ck_tile::make_tuple(merge_transform),
|
||||
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
|
||||
ck_tile::make_tuple(sequence<0>{})),
|
||||
ck_tile::make_tuple(number<S::kBlockM>{}),
|
||||
sequence<Problem::kPad>{});
|
||||
const auto transformed_y_m_n =
|
||||
pad_tensor_view(transform_tensor_view(
|
||||
y_m_n,
|
||||
ck_tile::make_tuple(merge_transform),
|
||||
ck_tile::make_tuple(make_index_sequence<DimsOutStrides::size()>{}),
|
||||
ck_tile::make_tuple(sequence<0>{})),
|
||||
ck_tile::make_tuple(number<S::kBlockM>{}),
|
||||
sequence<Problem::kPad>{});
|
||||
|
||||
auto y_window = make_tile_window(transformed_y_m_n,
|
||||
make_tuple(number<S::kBlockM>{}),
|
||||
|
||||
@@ -745,14 +745,6 @@ struct PassThroughPack2
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
|
||||
#if 0
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::fp16x2_t& y, const ck_tile::f8x2_t& x) const
|
||||
{
|
||||
auto t = type_convert<float2_t>(x);
|
||||
y = type_convert<fp16x2_t>(t);
|
||||
}
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(fp16x2_t& y, const pk_int4_t& x) const
|
||||
{
|
||||
uint8_t x_u8 = bit_cast<uint8_t>(x);
|
||||
@@ -871,61 +863,6 @@ struct UnaryConvert
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
struct ConvertBF16RTN
|
||||
{
|
||||
// convert to bf16 using round to nearest (rtn)
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
|
||||
{
|
||||
// check Y datatype
|
||||
static_assert(std::is_same_v<Y, ck_tile::bf16_t>, "Data type is not supported by this operation!");
|
||||
|
||||
// check X datatype
|
||||
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = bf16_convert_rtn<Y>(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertF8SR
|
||||
{
|
||||
// convert to fp8 using stochastic rounding (SR)
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
|
||||
{
|
||||
// check Y datatype
|
||||
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
// check X datatype
|
||||
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = f8_convert_sr<Y>(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertF8RNE
|
||||
{
|
||||
// convert to fp8 using rounding to nearest even
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
|
||||
{
|
||||
// check Y datatype
|
||||
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
// check X datatype
|
||||
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = f8_convert_rne<Y>(x);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
struct Scale
|
||||
{
|
||||
static constexpr const char* name = "Scale";
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/permuten_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
|
||||
@@ -33,7 +33,6 @@ template <typename AsDataType_,
|
||||
index_t kNumWaveGroups_ = 1,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1,
|
||||
bool TiledMMAPermuteN_ = false,
|
||||
index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp
|
||||
bool DoubleSmemBuffer_ = false>
|
||||
struct CShuffleEpilogueProblem
|
||||
@@ -59,7 +58,6 @@ struct CShuffleEpilogueProblem
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
@@ -658,152 +656,8 @@ struct CShuffleEpilogue
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
typename ScaleM = EmptyScale,
|
||||
typename ScaleN = EmptyScale,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* /* p_smem */,
|
||||
const ScaleM& scale_m = {},
|
||||
const ScaleN& scale_n = {})
|
||||
{
|
||||
static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size();
|
||||
|
||||
static_assert(MPerXdl % RowsPerLane == 0,
|
||||
"CShuffle (permuteN): MPerXdl must be divisible by per-lane row count.");
|
||||
constexpr int kM0 = MWave;
|
||||
constexpr int kM2 = RowsPerLane;
|
||||
constexpr int kM1 = MPerXdl / kM2;
|
||||
|
||||
constexpr int kN0 = NWave;
|
||||
constexpr int kN1 = NPerXdl;
|
||||
constexpr int kN2 = NRepeat;
|
||||
|
||||
using IntrThreadShuffleEncode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>;
|
||||
constexpr auto dram_tile_distribution =
|
||||
make_static_tile_distribution(IntrThreadShuffleEncode{});
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
|
||||
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
|
||||
|
||||
// Optional scales (must share the same distribution to match per-thread indexing)
|
||||
constexpr bool has_scales =
|
||||
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
|
||||
constexpr bool has_scalar_scales =
|
||||
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
|
||||
|
||||
// Tiles to hold row/col scales when present
|
||||
using SMType = typename ScaleDataType<ScaleM>::DataType;
|
||||
using SNType = typename ScaleDataType<ScaleN>::DataType;
|
||||
|
||||
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
|
||||
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
|
||||
|
||||
// Build windows only if non-scalar scales are provided
|
||||
auto scale_m_window = [&]() {
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
return make_tile_window(scale_m, dram_tile_distribution);
|
||||
}
|
||||
else
|
||||
{
|
||||
return EmptyScale{};
|
||||
}
|
||||
}();
|
||||
auto scale_n_window = [&]() {
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
return make_tile_window(scale_n, dram_tile_distribution);
|
||||
}
|
||||
else
|
||||
{
|
||||
return EmptyScale{};
|
||||
}
|
||||
}();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
// Slice accumulators for this M repeat into the permuted layout
|
||||
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
// If non-scalar scales provided, load them with identical distribution
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
|
||||
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
|
||||
}
|
||||
|
||||
// Pack 4 “rows per lane” as you already do
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
// source indices in shuffle_acc: (n_idx * product(Y) + row)
|
||||
const index_t plane = c_warp_y_lengths.product();
|
||||
|
||||
// local lambda to fuse scale (if present) and convert
|
||||
static_for<0, kM2, 1>{}([&](auto m_lane) {
|
||||
const int src = n_idx * plane + m_lane; // source row in this N-plane
|
||||
const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output
|
||||
AccDataType v = shuffle_acc.get_thread_buffer()[src];
|
||||
|
||||
if constexpr(has_scalar_scales)
|
||||
{
|
||||
v = static_cast<AccDataType>(v * scale_m * scale_n);
|
||||
}
|
||||
else if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
const auto sm = static_cast<float>(sm_tile.get_thread_buffer()[dst]);
|
||||
const auto sn = static_cast<float>(sn_tile.get_thread_buffer()[dst]);
|
||||
v = static_cast<AccDataType>(v * sm * sn);
|
||||
}
|
||||
|
||||
c_out_tensor.get_thread_buffer()[dst] = type_convert<ODataType>(v);
|
||||
});
|
||||
});
|
||||
|
||||
// store/update
|
||||
if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
|
||||
memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
|
||||
// advance output (and any D-tensors) by one MPerXdl*MWave chunk
|
||||
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
typename ScaleM = EmptyScale,
|
||||
typename ScaleN = EmptyScale,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<!EnablePermuateN_, int> = 0>
|
||||
typename ScaleM = EmptyScale,
|
||||
typename ScaleN = EmptyScale>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
|
||||
375
include/ck_tile/ops/epilogue/permuten_epilogue.hpp
Normal file
375
include/ck_tile/ops/epilogue/permuten_epilogue.hpp
Normal file
@@ -0,0 +1,375 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename DsDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename CDElementwise_,
|
||||
index_t kM_,
|
||||
index_t kN_,
|
||||
index_t MWave_,
|
||||
index_t NWave_,
|
||||
index_t MPerXdl_,
|
||||
index_t NPerXdl_,
|
||||
index_t KPerXdl_,
|
||||
bool isCTransposed_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1>
|
||||
struct PermuteNEpilogueProblem
|
||||
{
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using DsDataType = remove_cvref_t<DsDataType_>;
|
||||
using DsLayout = remove_cvref_t<DsLayout_>;
|
||||
using ELayout = remove_cvref_t<ELayout_>;
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
static constexpr index_t MWave = MWave_;
|
||||
static constexpr index_t NWave = NWave_;
|
||||
static constexpr index_t MPerXdl = MPerXdl_;
|
||||
static constexpr index_t NPerXdl = NPerXdl_;
|
||||
static constexpr index_t KPerXdl = KPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct PermuteNEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
|
||||
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
|
||||
|
||||
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
|
||||
|
||||
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<AsDataType>,
|
||||
remove_cvref_t<tuple<AsDataType>>>;
|
||||
|
||||
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<BsDataType>,
|
||||
remove_cvref_t<tuple<BsDataType>>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
|
||||
|
||||
using ATypeToUse = std::conditional_t<std::is_same_v<ADataType, pk_int4_t> ||
|
||||
std::is_same_v<ADataType, pk_fp4_t>,
|
||||
BDataType,
|
||||
ADataType>;
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
|
||||
std::is_same_v<BDataType, pk_fp4_t> ||
|
||||
sizeof(BDataType) < sizeof(ADataType),
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
using ELayout = remove_cvref_t<typename Problem::ELayout>;
|
||||
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr index_t MWave = Problem::MWave;
|
||||
static constexpr index_t NWave = Problem::NWave;
|
||||
static constexpr index_t MPerXdl = Problem::MPerXdl;
|
||||
static constexpr index_t NPerXdl = Problem::NPerXdl;
|
||||
static constexpr index_t KPerXdl = Problem::KPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t NumDTensor = Problem::NumDTensor;
|
||||
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
|
||||
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
|
||||
|
||||
CDElementwise elfunc_;
|
||||
|
||||
// PermuteN epilogue does not support D tensors or non-passthrough elementwise operations.
|
||||
// If D tensor support is needed, use CShuffleEpilogue instead.
|
||||
static_assert(NumDTensor == 0,
|
||||
"PermuteNEpilogue does not support D tensors. Use CShuffleEpilogue instead.");
|
||||
static_assert(std::is_same_v<CDElementwise, element_wise::PassThrough>,
|
||||
"PermuteNEpilogue only supports PassThrough elementwise. "
|
||||
"Use CShuffleEpilogue for custom elementwise operations.");
|
||||
|
||||
CK_TILE_DEVICE PermuteNEpilogue(CDElementwise elfunc = CDElementwise{}) : elfunc_(elfunc) {};
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "PermuteNEpilogue",
|
||||
concat('x', MWave, NWave),
|
||||
concat('x', MPerXdl, NPerXdl, KPerXdl),
|
||||
VectorSizeC,
|
||||
isCTransposed ? "CTransposed" : "CNotTransposed");
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the vector store size for C tensor.
|
||||
*
|
||||
* @note The vector store size for output C tensor would depend on multiple factors
|
||||
* like its data layout and warp gemm C transposition. In general it would
|
||||
* be the number of consecutive elements in contiguous C dimension hold by
|
||||
* single thread.
|
||||
*
|
||||
* @return The vector store size for C tensor.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
|
||||
{
|
||||
if constexpr(FixedVectorSize)
|
||||
{
|
||||
return VectorSizeC;
|
||||
}
|
||||
constexpr index_t max_vector_size = 16;
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(NPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(ODataType)));
|
||||
}
|
||||
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(MPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(ODataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported ELayout!");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the vector store size for Di tensor.
|
||||
*
|
||||
* @return The vector store size for Di tensor.
|
||||
*/
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
|
||||
{
|
||||
constexpr index_t max_vector_size = 16;
|
||||
using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(NPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(DiDataType)));
|
||||
}
|
||||
else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(MPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(DiDataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported DLayout!");
|
||||
}
|
||||
return max_vector_size / sizeof(DiDataType);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
using WG = WarpGemmDispatcher<ATypeToUse,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
KPerXdl,
|
||||
isCTransposed>;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
|
||||
|
||||
// TODO: Check if there would be nicer ways to overload rather than with EmptyScale or nullptr_t
|
||||
struct EmptyScale
|
||||
{
|
||||
};
|
||||
|
||||
template <typename, typename = void>
|
||||
struct ScaleDataType
|
||||
{
|
||||
using DataType = float;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ScaleDataType<T, std::void_t<typename T::DataType>>
|
||||
{
|
||||
using DataType = typename T::DataType;
|
||||
};
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
typename ScaleM = EmptyScale,
|
||||
typename ScaleN = EmptyScale>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* /* p_smem */,
|
||||
const ScaleM& scale_m = {},
|
||||
const ScaleN& scale_n = {})
|
||||
{
|
||||
static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size();
|
||||
|
||||
static_assert(MPerXdl % RowsPerLane == 0,
|
||||
"PermuteN: MPerXdl must be divisible by per-lane row count.");
|
||||
constexpr int kM0 = MWave;
|
||||
constexpr int kM2 = RowsPerLane;
|
||||
constexpr int kM1 = MPerXdl / kM2;
|
||||
|
||||
constexpr int kN0 = NWave;
|
||||
constexpr int kN1 = NPerXdl;
|
||||
constexpr int kN2 = NRepeat;
|
||||
|
||||
using IntrThreadShuffleEncode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>;
|
||||
constexpr auto dram_tile_distribution =
|
||||
make_static_tile_distribution(IntrThreadShuffleEncode{});
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
|
||||
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
|
||||
|
||||
// Optional scales (must share the same distribution to match per-thread indexing)
|
||||
constexpr bool has_scales =
|
||||
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
|
||||
constexpr bool has_scalar_scales =
|
||||
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
|
||||
|
||||
// Tiles to hold row/col scales when present
|
||||
using SMType = typename ScaleDataType<ScaleM>::DataType;
|
||||
using SNType = typename ScaleDataType<ScaleN>::DataType;
|
||||
|
||||
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
|
||||
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
|
||||
|
||||
// Build windows only if non-scalar scales are provided
|
||||
auto scale_m_window = [&]() {
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
return make_tile_window(scale_m, dram_tile_distribution);
|
||||
}
|
||||
else
|
||||
{
|
||||
return EmptyScale{};
|
||||
}
|
||||
}();
|
||||
auto scale_n_window = [&]() {
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
return make_tile_window(scale_n, dram_tile_distribution);
|
||||
}
|
||||
else
|
||||
{
|
||||
return EmptyScale{};
|
||||
}
|
||||
}();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
// Slice accumulators for this M repeat into the permuted layout
|
||||
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
// If non-scalar scales provided, load them with identical distribution
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
|
||||
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
|
||||
}
|
||||
|
||||
// Pack "rows per lane" with permuted N layout
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
// source indices in shuffle_acc: (n_idx * product(Y) + row)
|
||||
const index_t plane = c_warp_y_lengths.product();
|
||||
|
||||
// Fuse scale (if present) and convert
|
||||
static_for<0, kM2, 1>{}([&](auto m_lane) {
|
||||
const int src = n_idx * plane + m_lane; // source row in this N-plane
|
||||
const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output
|
||||
AccDataType v = shuffle_acc.get_thread_buffer()[src];
|
||||
|
||||
if constexpr(has_scalar_scales)
|
||||
{
|
||||
v = static_cast<AccDataType>(v * scale_m * scale_n);
|
||||
}
|
||||
else if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
const auto sm = static_cast<float>(sm_tile.get_thread_buffer()[dst]);
|
||||
const auto sn = static_cast<float>(sn_tile.get_thread_buffer()[dst]);
|
||||
v = static_cast<AccDataType>(v * sm * sn);
|
||||
}
|
||||
|
||||
c_out_tensor.get_thread_buffer()[dst] = type_convert<ODataType>(v);
|
||||
});
|
||||
});
|
||||
|
||||
// store/update
|
||||
if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
|
||||
memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
|
||||
// advance output (and any D-tensors) by one MPerXdl*MWave chunk
|
||||
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -339,16 +339,6 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
|
||||
{
|
||||
return hostArgs;
|
||||
}
|
||||
// CK_TILE_HOST static constexpr auto
|
||||
// MakeKernelArgs(const ContiguousGroupedFlatmmHostArgs& hostArgs)
|
||||
// {
|
||||
// return hostArgs;
|
||||
// }
|
||||
// CK_TILE_HOST static constexpr auto
|
||||
// MakeKernelArgs(const MaskedGroupedFlatmmHostArgs& hostArgs)
|
||||
// {
|
||||
// return hostArgs;
|
||||
// }
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
|
||||
@@ -483,13 +483,6 @@ struct MoeFlatmmKernel
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
// if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
|
||||
// {
|
||||
// std::cerr << "Can't support N that is not a multiple of NPerBlock"
|
||||
// " without padding!"
|
||||
// << std::endl;
|
||||
// return false;
|
||||
// }
|
||||
if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
|
||||
{
|
||||
std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
|
||||
|
||||
@@ -392,10 +392,6 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// "Incorrect M0, M2, M1 configuration! "
|
||||
// "M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
|
||||
@@ -1151,11 +1151,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
// barrier
|
||||
// if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
// {
|
||||
// block_sync_lds();
|
||||
// }
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -1636,10 +1631,6 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
? Aload_rep
|
||||
: 0;
|
||||
}
|
||||
// if((kIter % KPerScaleLoad == 0) && (mIter == 0))
|
||||
// {
|
||||
// load_perM = load_perM + 1;
|
||||
// }
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,13 +103,8 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t Aload_num_perK = dswrite_num_perK;
|
||||
static constexpr index_t Aload_rep = dswrite_rep;
|
||||
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
|
||||
// static constexpr index_t ScaleBload_K1 = ContinuousScaleNPerThread *
|
||||
// ContinuousScaleKPerThread; static constexpr index_t ScaleBload_num =
|
||||
// kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 /
|
||||
// WaveSize; // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize
|
||||
// static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
|
||||
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
|
||||
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
|
||||
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
|
||||
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
|
||||
|
||||
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
|
||||
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
|
||||
@@ -352,10 +347,6 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
? Aload_rep
|
||||
: 0;
|
||||
}
|
||||
// if((kIter % KPerScaleLoad == 0) && (mIter == 0))
|
||||
// {
|
||||
// load_perM = load_perM + 1;
|
||||
// }
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -390,10 +390,6 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
? Aload_rep
|
||||
: 0;
|
||||
}
|
||||
// if((kIter % KPerScaleLoad == 0) && (mIter == 0))
|
||||
// {
|
||||
// load_perM = load_perM + 1;
|
||||
// }
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -759,18 +759,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
: -numeric<float>::infinity();
|
||||
const index_t seqlen_k = [&]() {
|
||||
// WA i_batch capture structure binding before c++20
|
||||
const index_t seqlen_k = [&, i_batch_ = i_batch]() {
|
||||
if constexpr(kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
{
|
||||
const int32_t page_start = kargs.page_table.kv_indptr[i_batch];
|
||||
const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1];
|
||||
const int32_t page_start = kargs.page_table.kv_indptr[i_batch_];
|
||||
const int32_t page_end = kargs.page_table.kv_indptr[i_batch_ + 1];
|
||||
const int32_t num_page_blocks = page_end - page_start;
|
||||
const int32_t last_page_len = [&]() {
|
||||
if constexpr(kPageBlockSize == 1)
|
||||
return static_cast<int32_t>(kPageBlockSize);
|
||||
else
|
||||
return kargs.page_table.kv_last_page_lens[i_batch];
|
||||
return kargs.page_table.kv_last_page_lens[i_batch_];
|
||||
}();
|
||||
return num_page_blocks > 0
|
||||
? static_cast<index_t>((num_page_blocks - 1) * kargs.page_block_size +
|
||||
@@ -780,21 +781,22 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
|
||||
{
|
||||
if(kargs.page_table.seqlen_k_ptr != nullptr)
|
||||
return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch]);
|
||||
return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch_]);
|
||||
else
|
||||
return kargs.seqlen_k;
|
||||
}
|
||||
}();
|
||||
const int32_t* page_idx = [&]() {
|
||||
// WA i_batch capture structure binding before c++20
|
||||
const int32_t* page_idx = [&, i_batch_ = i_batch]() {
|
||||
if constexpr(kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
{
|
||||
return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch];
|
||||
return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch_];
|
||||
}
|
||||
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
|
||||
{
|
||||
return kargs.page_table.block_table_ptr +
|
||||
static_cast<long_index_t>(i_batch) *
|
||||
static_cast<long_index_t>(i_batch_) *
|
||||
kargs.page_table.batch_stride_block_table;
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -39,6 +39,9 @@ struct FmhaFwdKernel
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
|
||||
template <typename T>
|
||||
using has_hdim_tail_args = decltype(T::kUseHdimTailArgs);
|
||||
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
static_assert(kBlockPerCu > 0);
|
||||
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
|
||||
@@ -1891,6 +1894,35 @@ struct FmhaFwdKernel
|
||||
}();
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead_k};
|
||||
constexpr bool kPassHdimTailArgs = [] {
|
||||
if constexpr(ck_tile::is_detected<has_hdim_tail_args, FmhaPipeline>::value)
|
||||
return static_cast<bool>(FmhaPipeline::kUseHdimTailArgs);
|
||||
else
|
||||
return false;
|
||||
}();
|
||||
auto invoke_fmha_pipeline = [&](auto&&... args) -> decltype(auto) {
|
||||
if constexpr(kPassHdimTailArgs)
|
||||
{
|
||||
const ck_tile::index_t valid_k0_loops =
|
||||
ck_tile::integer_divide_ceil(kargs.hdim_q, FmhaPipeline::kK0);
|
||||
const ck_tile::index_t valid_last_k0_length =
|
||||
kargs.hdim_q - (valid_k0_loops - 1) * FmhaPipeline::kK0;
|
||||
const ck_tile::index_t valid_n1_length = [&]() {
|
||||
const ck_tile::index_t remaining_n1 = kargs.hdim_v - i_n1;
|
||||
return ck_tile::min(remaining_n1,
|
||||
static_cast<ck_tile::index_t>(FmhaPipeline::kN1));
|
||||
}();
|
||||
return FmhaPipeline{}(static_cast<decltype(args)&&>(args)...,
|
||||
sink_value,
|
||||
valid_k0_loops,
|
||||
valid_last_k0_length,
|
||||
valid_n1_length);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(static_cast<decltype(args)&&>(args)..., sink_value);
|
||||
}
|
||||
};
|
||||
|
||||
auto o_acc_tile = [&, i_nhead_ = i_nhead, i_nhead_k_ = i_nhead_k]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
@@ -1910,36 +1942,35 @@ struct FmhaFwdKernel
|
||||
else
|
||||
return ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
|
||||
}();
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales<remove_cvref_t<decltype(scale_p)>>{
|
||||
scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
sink_value);
|
||||
return invoke_fmha_pipeline(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales<remove_cvref_t<decltype(scale_p)>>{
|
||||
scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()));
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
@@ -1964,7 +1995,7 @@ struct FmhaFwdKernel
|
||||
// Both P and rowsum are scaled by 2^shift, canceling in normalization
|
||||
// No additional scaling needed in p_compute_element_func or o_acc_element_func
|
||||
|
||||
return FmhaPipeline{}(
|
||||
return invoke_fmha_pipeline(
|
||||
q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
@@ -1992,8 +2023,7 @@ struct FmhaFwdKernel
|
||||
kargs.block_scale_size_kv,
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
sink_value);
|
||||
make_null_tile_window(make_tuple()));
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX)
|
||||
{
|
||||
@@ -2098,53 +2128,51 @@ struct FmhaFwdKernel
|
||||
number<FmhaPipeline::kK1 / kVScaleGranularity>{}),
|
||||
{i_n1, 0});
|
||||
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
identity{}, // p_compute_element_func
|
||||
identity{}, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
q_scale_dram_window,
|
||||
k_scale_dram_window,
|
||||
v_scale_dram_window,
|
||||
sink_value);
|
||||
return invoke_fmha_pipeline(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
identity{}, // p_compute_element_func
|
||||
identity{}, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
q_scale_dram_window,
|
||||
k_scale_dram_window,
|
||||
v_scale_dram_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
sink_value);
|
||||
return invoke_fmha_pipeline(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -291,6 +291,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
// For KV_BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift]
|
||||
// This avoids explicit P *= scale_p and v_descale /= scale_p operations
|
||||
@@ -546,11 +547,25 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
if constexpr(kHasSink)
|
||||
return mask.GetSinkTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
else
|
||||
{
|
||||
auto [start, end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
return ck_tile::make_tuple(0, start, end);
|
||||
}
|
||||
}();
|
||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
||||
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
|
||||
const auto num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||
@@ -576,7 +591,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0});
|
||||
{kv_load_start, 0});
|
||||
|
||||
auto k_dist = Policy::template MakeKDramTileDistribution<Problem>();
|
||||
auto k_coord = k_dist.calculate_index();
|
||||
@@ -585,7 +600,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
// kPageBlockSize >= kN0: within-page offset only (SRD rebased per page via rebase_k_window)
|
||||
// kPageBlockSize < kN0: global offset, must fit int32
|
||||
statically_indexed_array<index_t, NRepeat> k_offsets;
|
||||
index_t current_seq_k = seqlen_k_start;
|
||||
index_t current_seq_k = kv_load_start;
|
||||
|
||||
// Load physical pages first, then compute offsets.
|
||||
// k_physical_pages can be reused for descale lookup later.
|
||||
@@ -668,11 +683,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||
randval_dram_block_window_tmp, seqlen_k_start);
|
||||
randval_dram_block_window_tmp, kv_load_start);
|
||||
|
||||
auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
|
||||
auto v_coord = v_dist.calculate_index();
|
||||
@@ -895,7 +910,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto v_dram_window =
|
||||
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
{0, kv_load_start}, // TODO: hdim split?
|
||||
v_dist,
|
||||
v_offsets,
|
||||
number<1>{}, // HsGatherDim
|
||||
@@ -1097,6 +1112,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
if(i_total_loops == num_sink_loop - 1)
|
||||
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
@@ -1108,19 +1128,36 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !variant.LogitsMask(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask_func(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
};
|
||||
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsSinkMask(
|
||||
std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1297,12 +1334,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
auto randval_ptr = reinterpret_cast<char*>(smem_ptr) +
|
||||
Policy::template GetSmemSizeKV<Problem>();
|
||||
index_t seq_offset = [&]() {
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
const bool in_sink_phase = (num_sink_loop > i_total_loops);
|
||||
if(i_total_loops == num_sink_loop)
|
||||
move_tile_window(randval_dram_window,
|
||||
{0, seqlen_k_start - sink_seq_end});
|
||||
return in_sink_phase
|
||||
? (kv_load_start + i_total_loops * kN0)
|
||||
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
|
||||
}
|
||||
else
|
||||
return seqlen_k_start + i_total_loops * kN0;
|
||||
}();
|
||||
dropout
|
||||
.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||
randval_ptr,
|
||||
seqlen_k_start + i_total_loops * kN0,
|
||||
p_compute,
|
||||
randval_dram_window);
|
||||
randval_ptr, seq_offset, p_compute, randval_dram_window);
|
||||
}
|
||||
|
||||
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
@@ -1396,9 +1444,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
i_total_loops++;
|
||||
if(i_total_loops < num_total_loop)
|
||||
{
|
||||
current_seq_k += kN0;
|
||||
// For sink: after the last sink tile, jump K/V to seqlen_k_start;
|
||||
// otherwise advance by one normal tile.
|
||||
const index_t k_advance = [&]() -> index_t {
|
||||
if constexpr(kHasSink)
|
||||
return (i_total_loops == num_sink_loop)
|
||||
? (seqlen_k_start - sink_seq_end + kN0)
|
||||
: kN0;
|
||||
else
|
||||
return kN0;
|
||||
}();
|
||||
current_seq_k += k_advance;
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
move_tile_window(k_dram_block_window, {k_advance, 0});
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
// KV_BLOCKSCALE: reload physical pages for the new tile
|
||||
@@ -1427,6 +1485,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
k_dram_window.update_page_idx(k_offsets);
|
||||
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
|
||||
|
||||
// After sink→window transition (i_total_loops == num_sink_loop), V window
|
||||
// was advanced by kN0 (one normal iter), but current_seq_k jumped by k_advance
|
||||
// = seqlen_k_start - sink_seq_end + kN0 > kN0. Re-init V to current_seq_k.
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
if(i_total_loops == num_sink_loop && num_sink_loop > 0)
|
||||
{
|
||||
prefetch_v_physical_pages(number<0>{});
|
||||
update_v_offsets(number<0>{});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
@@ -13,6 +13,7 @@ enum class BlockFmhaPipelineEnum
|
||||
QSKSVS,
|
||||
QRKSVS_ASYNC_TRLOAD,
|
||||
QRKSVS_ASYNC_TRLOAD_V3,
|
||||
QRKSVS_HPAD,
|
||||
};
|
||||
|
||||
template <BlockFmhaPipelineEnum>
|
||||
@@ -40,4 +41,10 @@ struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD>
|
||||
static constexpr const char* name = "qr_async_trload";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_HPAD>
|
||||
{
|
||||
static constexpr const char* name = "qr_hpad";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is qkv all located in LDS
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy,
|
||||
bool PaddedVecLoadStore_ = false>
|
||||
struct BlockFmhaPipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
@@ -37,6 +39,11 @@ struct BlockFmhaPipelineQRKSVS
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
template <typename T>
|
||||
using has_partial_k_support = decltype(T::kSupportsPartialK);
|
||||
template <typename T>
|
||||
using has_partial_n_support = decltype(T::kSupportsPartialN);
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
|
||||
@@ -54,17 +61,19 @@ struct BlockFmhaPipelineQRKSVS
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
static constexpr bool kPaddedVecLoadStore = PaddedVecLoadStore_;
|
||||
static constexpr bool kUseHdimTailArgs = kPadHeadDimQ || kPadHeadDimV;
|
||||
|
||||
static constexpr ck_tile::index_t kQKScaleGranularity = Problem::kQKScaleGranularity;
|
||||
static constexpr ck_tile::index_t kVScaleGranularity = Problem::kVScaleGranularity;
|
||||
@@ -80,23 +89,29 @@ struct BlockFmhaPipelineQRKSVS
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
!kHasLogitsSoftCap)) ||
|
||||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
|
||||
static_assert(!kPaddedVecLoadStore || (kPadHeadDimQ && kPadHeadDimV),
|
||||
"padded vector load/store fast path only applies to padded head-dim kernels");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ = kPadHeadDimQ ? numeric_traits<QDataType>::PackedSize
|
||||
: Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK = kPadHeadDimQ ? numeric_traits<KDataType>::PackedSize
|
||||
: Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentQ = (kPadHeadDimQ && !kPaddedVecLoadStore)
|
||||
? numeric_traits<QDataType>::PackedSize
|
||||
: Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK = (kPadHeadDimQ && !kPaddedVecLoadStore)
|
||||
? numeric_traits<KDataType>::PackedSize
|
||||
: Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
return (kPadHeadDimV && !kPaddedVecLoadStore)
|
||||
? 1
|
||||
: Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? numeric_traits<VDataType>::PackedSize
|
||||
: Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
(kPadHeadDimV && !kPaddedVecLoadStore) ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
static constexpr index_t kAlignmentRandVal =
|
||||
@@ -194,7 +209,10 @@ struct BlockFmhaPipelineQRKSVS
|
||||
k_scale_dram_block_window_tmp, // N0*(K0/kQKScaleGranularity) tile
|
||||
const VScaleDramBlockWindowTmp&
|
||||
v_scale_dram_block_window_tmp, // N1*(K1/kVScaleGranularity) tile
|
||||
const float sink_v) const
|
||||
const float sink_v,
|
||||
const index_t valid_k0_loops,
|
||||
const index_t valid_last_k0_length,
|
||||
const index_t valid_n1_length) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -252,8 +270,30 @@ struct BlockFmhaPipelineQRKSVS
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
using BlockGemm1 = remove_cvref_t<decltype(gemm_1)>;
|
||||
constexpr bool kBlockGemm0SupportsPartialK = [] {
|
||||
if constexpr(ck_tile::is_detected<has_partial_k_support, BlockGemm0>::value)
|
||||
return static_cast<bool>(BlockGemm0::kSupportsPartialK);
|
||||
else
|
||||
return false;
|
||||
}();
|
||||
constexpr bool kBlockGemm1SupportsPartialN = [] {
|
||||
if constexpr(ck_tile::is_detected<has_partial_n_support, BlockGemm1>::value)
|
||||
return static_cast<bool>(BlockGemm1::kSupportsPartialN);
|
||||
else
|
||||
return false;
|
||||
}();
|
||||
|
||||
constexpr auto gemm_0_config =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using Gemm0WarpGemm = remove_cvref_t<decltype(gemm_0_config.template at<0>())>;
|
||||
constexpr index_t kGemm0WarpK = Gemm0WarpGemm::kK;
|
||||
constexpr index_t kGemm0KItersPerBlock = kK0 / kGemm0WarpK;
|
||||
constexpr bool kUsePartialKForGemm0Tail =
|
||||
kPadHeadDimQ && kBlockGemm0SupportsPartialK && (kGemm0KItersPerBlock > 1);
|
||||
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
@@ -419,10 +459,26 @@ struct BlockFmhaPipelineQRKSVS
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
// Number of k0 iterations prefetched ahead of the current compute iteration.
|
||||
// The skip decision must be made this many iterations before the last k0 loop.
|
||||
constexpr index_t kK0PrefetchDepth = 2;
|
||||
const index_t gemm0_tail_k_iters = [&]() {
|
||||
if constexpr(kUsePartialKForGemm0Tail)
|
||||
{
|
||||
return ck_tile::integer_divide_ceil(valid_last_k0_length, kGemm0WarpK);
|
||||
}
|
||||
return static_cast<index_t>(kGemm0KItersPerBlock);
|
||||
}();
|
||||
const bool skip_last_k0_loop = [&]() {
|
||||
if constexpr(kPadHeadDimQ)
|
||||
{
|
||||
return valid_k0_loops == (k0_loops - 1);
|
||||
}
|
||||
return false;
|
||||
}();
|
||||
// Use compile-time conditional for group barrier sequence
|
||||
// (No runtime lambda selection)
|
||||
auto schedule_gemm_0 = [] {
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
constexpr auto WarpGemmConfig =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
|
||||
@@ -447,7 +503,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
};
|
||||
|
||||
static_assert(2 <= k0_loops);
|
||||
static_assert(kK0PrefetchDepth <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
do
|
||||
{
|
||||
@@ -514,6 +570,46 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
|
||||
auto run_gemm_0 = [&](auto i_k0) {
|
||||
if constexpr(kUsePartialKForGemm0Tail)
|
||||
{
|
||||
if(static_cast<index_t>(i_k0.value) == (valid_k0_loops - 1) &&
|
||||
gemm0_tail_k_iters < kGemm0KItersPerBlock)
|
||||
{
|
||||
static_for<1, kGemm0KItersPerBlock, 1>{}([&](auto i_tail_k_iter) {
|
||||
constexpr index_t kTailKIters = i_tail_k_iter;
|
||||
constexpr index_t kTailK0 = kTailKIters * kGemm0WarpK;
|
||||
|
||||
if(gemm0_tail_k_iters == kTailKIters)
|
||||
{
|
||||
using Gemm0TailProblem = BlockGemmProblem<
|
||||
QDataType,
|
||||
KDataType,
|
||||
SaccDataType,
|
||||
Problem::kNumGemm0Warps * get_warp_size(),
|
||||
TileGemmShape<
|
||||
sequence<kM0, kN0, kTailK0>,
|
||||
typename BlockFmhaShape::Gemm0BlockWarps,
|
||||
sequence<BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
|
||||
BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
kGemm0WarpK>>>;
|
||||
constexpr auto gemm_0_tail =
|
||||
BlockGemmARegBSmemCRegV2<Gemm0TailProblem,
|
||||
typename BlockGemm0::Policy>{};
|
||||
|
||||
auto q_slice =
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, i_k0 * kK0 + kTailK0>{});
|
||||
auto k_tail_window = make_tile_window(
|
||||
k_lds, make_tuple(number<kN0>{}, number<kTailK0>{}), {0, 0});
|
||||
|
||||
gemm_0_tail(s_acc, q_slice, k_tail_window);
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto q_slice = get_slice_tile(
|
||||
q_tile, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{});
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX)
|
||||
@@ -531,36 +627,78 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
if constexpr(k0_loops > kK0PrefetchDepth)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
static_for<0, k0_loops - kK0PrefetchDepth, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
run_gemm_0(number<i_k0>{});
|
||||
block_sync_lds();
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
if constexpr(kPadHeadDimQ && i_k0 == (k0_loops - 1 - kK0PrefetchDepth))
|
||||
{
|
||||
if(!skip_last_k0_loop)
|
||||
{
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
}
|
||||
|
||||
store_tile(
|
||||
k_lds_window,
|
||||
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 2
|
||||
store_tile(
|
||||
k_lds_window,
|
||||
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
|
||||
|
||||
if(!skip_last_k0_loop)
|
||||
{
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 2
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
store_tile(
|
||||
k_lds_window,
|
||||
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 2
|
||||
}
|
||||
k_scale_block_tile = load_k_scale_block_tile();
|
||||
});
|
||||
}
|
||||
|
||||
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
run_gemm_0(number<k0_loops - 2>{});
|
||||
auto v_prefetch = decltype(load_tile(v_dram_window)){};
|
||||
enum class VPrefetchPoint
|
||||
{
|
||||
BeforeGemm0Tail,
|
||||
AfterGemm0Tail,
|
||||
AfterSoftmax
|
||||
};
|
||||
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
constexpr auto kVPrefetch =
|
||||
kPadHeadDimV ? VPrefetchPoint::AfterSoftmax : VPrefetchPoint::AfterGemm0Tail;
|
||||
#else
|
||||
constexpr auto kVPrefetch = VPrefetchPoint::BeforeGemm0Tail;
|
||||
#endif
|
||||
if constexpr(kVPrefetch == VPrefetchPoint::BeforeGemm0Tail)
|
||||
{
|
||||
load_tile(v_prefetch, v_dram_window); // prefetch load v tile
|
||||
}
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
run_gemm_0(number<k0_loops - kK0PrefetchDepth>{});
|
||||
if(!skip_last_k0_loop)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
|
||||
k_scale_block_tile = load_k_scale_block_tile();
|
||||
k_scale_block_tile = load_k_scale_block_tile();
|
||||
|
||||
block_sync_lds();
|
||||
block_sync_lds();
|
||||
|
||||
run_gemm_0(number<k0_loops - 1>{});
|
||||
run_gemm_0(number<k0_loops - 1>{});
|
||||
}
|
||||
}
|
||||
if constexpr(kVPrefetch == VPrefetchPoint::AfterGemm0Tail)
|
||||
{
|
||||
load_tile(v_prefetch, v_dram_window);
|
||||
}
|
||||
// dequant
|
||||
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
|
||||
@@ -819,6 +957,11 @@ struct BlockFmhaPipelineQRKSVS
|
||||
randval_ptr, seq_offset, p_compute, randval_dram_window);
|
||||
}
|
||||
|
||||
if constexpr(kVPrefetch == VPrefetchPoint::AfterSoftmax)
|
||||
{
|
||||
load_tile(v_prefetch, v_dram_window);
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -898,6 +1041,31 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto o_acc0 = decltype(o_acc){};
|
||||
clear_tile(o_acc0);
|
||||
|
||||
constexpr auto gemm_1_config =
|
||||
BlockGemm1::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using Gemm1WarpGemm = remove_cvref_t<decltype(gemm_1_config.template at<0>())>;
|
||||
constexpr index_t kGemm1NWarp = gemm_1_config.template at<2>();
|
||||
constexpr index_t kGemm1NPerIter = kGemm1NWarp * Gemm1WarpGemm::kN;
|
||||
const index_t valid_n_iters = [&]() {
|
||||
if constexpr(kPadHeadDimV && kBlockGemm1SupportsPartialN)
|
||||
{
|
||||
return ck_tile::integer_divide_ceil(valid_n1_length, kGemm1NPerIter);
|
||||
}
|
||||
return static_cast<index_t>(0);
|
||||
}();
|
||||
|
||||
auto run_gemm_1_impl =
|
||||
[&](auto& o_acc_tensor, const auto& p_slice, const auto&... gemm_1_args) {
|
||||
if constexpr(kPadHeadDimV && kBlockGemm1SupportsPartialN)
|
||||
{
|
||||
gemm_1(o_acc_tensor, p_slice, gemm_1_args..., valid_n_iters);
|
||||
}
|
||||
else
|
||||
{
|
||||
gemm_1(o_acc_tensor, p_slice, gemm_1_args...);
|
||||
}
|
||||
};
|
||||
|
||||
auto run_gemm_1 = [&](auto i_k1) {
|
||||
auto p_slice =
|
||||
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
@@ -907,15 +1075,19 @@ struct BlockFmhaPipelineQRKSVS
|
||||
get_slice_tile(p_scale,
|
||||
sequence<0, i_k1*(kK1 / kVScaleGranularity)>{},
|
||||
sequence<kM0, (i_k1 + 1) * (kK1 / kVScaleGranularity)>{});
|
||||
gemm_1(o_acc, p_slice, p_scale_slice, v_lds_window, v_scale_block_tile);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
gemm_1(o_acc0, p_slice, v_lds_window);
|
||||
run_gemm_1_impl(
|
||||
o_acc, p_slice, p_scale_slice, v_lds_window, v_scale_block_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
gemm_1(o_acc, p_slice, v_lds_window);
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
run_gemm_1_impl(o_acc0, p_slice, v_lds_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
run_gemm_1_impl(o_acc, p_slice, v_lds_window);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1040,6 +1212,94 @@ struct BlockFmhaPipelineQRKSVS
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices,
|
||||
typename QScaleDramBlockWindowTmp,
|
||||
typename KScaleDramBlockWindowTmp,
|
||||
typename VScaleDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout,
|
||||
const float* k_descale_ptr,
|
||||
const float* v_descale_ptr,
|
||||
const index_t block_scale_size_kv,
|
||||
const QScaleDramBlockWindowTmp&
|
||||
q_scale_dram_block_window_tmp, // M0*(K0/kQKScaleGranularity) tile
|
||||
const KScaleDramBlockWindowTmp&
|
||||
k_scale_dram_block_window_tmp, // N0*(K0/kQKScaleGranularity) tile
|
||||
const VScaleDramBlockWindowTmp&
|
||||
v_scale_dram_block_window_tmp, // N1*(K1/kVScaleGranularity) tile
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
q_element_func,
|
||||
k_dram_block_window_tmp,
|
||||
k_element_func,
|
||||
v_dram_block_window_tmp,
|
||||
v_element_func,
|
||||
bias_dram_block_window_tmp,
|
||||
bias_element_func,
|
||||
randval_dram_block_window_tmp,
|
||||
lse_dram_window_tmp,
|
||||
lse_element_func,
|
||||
s_acc_element_func,
|
||||
p_compute_element_func,
|
||||
o_acc_element_func,
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
block_scale_size_kv,
|
||||
q_scale_dram_block_window_tmp,
|
||||
k_scale_dram_block_window_tmp,
|
||||
v_scale_dram_block_window_tmp,
|
||||
sink_v,
|
||||
kQKHeaddim / kK0,
|
||||
kK0,
|
||||
kN1);
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
@@ -1064,7 +1324,10 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
const float sink_v,
|
||||
const index_t valid_k0_loops,
|
||||
const index_t valid_last_k0_length,
|
||||
const index_t valid_n1_length) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -1094,8 +1357,60 @@ struct BlockFmhaPipelineQRKSVS
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
sink_v);
|
||||
sink_v,
|
||||
valid_k0_loops,
|
||||
valid_last_k0_length,
|
||||
valid_n1_length);
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
k_dram_block_window_tmp,
|
||||
v_dram_block_window_tmp,
|
||||
bias_dram_block_window_tmp,
|
||||
randval_dram_block_window_tmp,
|
||||
lse_dram_block_window_tmp,
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
sink_v,
|
||||
kQKHeaddim / kK0,
|
||||
kK0,
|
||||
kN1);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
|
||||
using BlockFmhaPipelineQRKSVSHpad = BlockFmhaPipelineQRKSVS<Problem_, Policy_, true>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -692,9 +692,6 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
// constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad);
|
||||
// constexpr index_t SingleVSize =
|
||||
// MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
constexpr index_t BufferSize =
|
||||
GetSingleSmemElementSpaceSize<Problem>(); // max(SingleKSize, SingleVSize);
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
BlockAttentionQuantScaleEnum QScaleEnum_,
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
|
||||
bool kHasSink_ = false, /* StreamLLM sink tokens */
|
||||
index_t kPageBlockSize_ = 1,
|
||||
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
|
||||
@@ -70,7 +71,7 @@ struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
|
||||
QScaleEnum_,
|
||||
kBlockPerCu_,
|
||||
kSkipMinSeqlenQ_,
|
||||
false>
|
||||
kHasSink_>
|
||||
{
|
||||
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
|
||||
static constexpr auto kKVLookupTable = kKVLookupTable_;
|
||||
|
||||
@@ -456,9 +456,6 @@ struct MoeSortingKernel
|
||||
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
|
||||
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
// constexpr int reduce_stage = 6; // 1<<6=64
|
||||
// clang-format off
|
||||
constexpr int reduce_stage = [](){
|
||||
if constexpr(wave_size_ == 2) return 1;
|
||||
else if constexpr(wave_size_ == 4) return 2;
|
||||
@@ -1206,17 +1203,21 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
|
||||
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
|
||||
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
// constexpr int reduce_stage = 6; // 1<<6=64
|
||||
// clang-format off
|
||||
constexpr int reduce_stage = [](){
|
||||
if constexpr(wave_size_ == 2) return 1;
|
||||
else if constexpr(wave_size_ == 4) return 2;
|
||||
else if constexpr(wave_size_ == 8) return 3;
|
||||
else if constexpr(wave_size_ == 16) return 4;
|
||||
else if constexpr(wave_size_ == 32) return 5;
|
||||
else if constexpr(wave_size_ == 64) return 6;
|
||||
else return 0;
|
||||
constexpr int reduce_stage = []() {
|
||||
if constexpr(wave_size_ == 2)
|
||||
return 1;
|
||||
else if constexpr(wave_size_ == 4)
|
||||
return 2;
|
||||
else if constexpr(wave_size_ == 8)
|
||||
return 3;
|
||||
else if constexpr(wave_size_ == 16)
|
||||
return 4;
|
||||
else if constexpr(wave_size_ == 32)
|
||||
return 5;
|
||||
else if constexpr(wave_size_ == 64)
|
||||
return 6;
|
||||
else
|
||||
return 0;
|
||||
}();
|
||||
// clang-format on
|
||||
T v_local = local;
|
||||
@@ -3047,53 +3048,6 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
x_r = x_v;
|
||||
#endif
|
||||
{
|
||||
#if 0
|
||||
#pragma unroll
|
||||
for(int j = 0; j < index_pack / 2; j++)
|
||||
{
|
||||
int i_token = i * kBlockSize * index_pack + threadIdx.x + j * kBlockSize;
|
||||
index_t x = x_d[j];
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show;
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int position = cumsum - i_show;
|
||||
prev_cumsum = s[0]; // update the last cumsum
|
||||
|
||||
if(i_show)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[e_start + position] =
|
||||
MOE_SORTING_MOCK_ID(i_token, i_topk);
|
||||
#else
|
||||
p_sorted_token_ids[e_start + position] = i_token;
|
||||
#endif
|
||||
p_sorted_weights[e_start + position] =
|
||||
p_weights[i_token * kargs.topk_mdiv.divisor + i_topk];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
{
|
||||
d_t i_topk;
|
||||
d_t i_show;
|
||||
@@ -3151,68 +3105,6 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
}
|
||||
position += i_show[j];
|
||||
});
|
||||
|
||||
#if 0
|
||||
int i_token = i * kBlockSize * index_pack + threadIdx.x * 2 + j * kBlockSize * 2;
|
||||
index_t x = x_d[j];
|
||||
index_t x0 = static_cast<index_t>(x & 0xffff);
|
||||
index_t x1 = static_cast<index_t>(x >> 16);
|
||||
int i_topk_0 = x0 - 1; // topk of this token
|
||||
int i_show_0 = x0 != 0 ? 1 : 0; // has this token or not
|
||||
int i_topk_1 = x1 - 1; // topk of this token
|
||||
int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show_0 + i_show_1;
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int position_0 = cumsum - i_show_0 - i_show_1;
|
||||
prev_cumsum = s[0]; // update the last cumsum
|
||||
|
||||
if(i_show_0)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[e_start + position_0] =
|
||||
MOE_SORTING_MOCK_ID(i_token, i_topk_0);
|
||||
#else
|
||||
p_sorted_token_ids[e_start + position_0] = i_token;
|
||||
#endif
|
||||
p_sorted_weights[e_start + position_0] =
|
||||
p_weights[i_token * kargs.topk_mdiv.divisor + i_topk_0];
|
||||
}
|
||||
|
||||
int position_1 = cumsum - i_show_1;
|
||||
|
||||
if(i_show_1)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[e_start + position_1] =
|
||||
MOE_SORTING_MOCK_ID(i_token + 1, i_topk_1);
|
||||
#else
|
||||
p_sorted_token_ids[e_start + position_1] = i_token + 1;
|
||||
#endif
|
||||
p_sorted_weights[e_start + position_1] =
|
||||
p_weights[(i_token + 1) * kargs.topk_mdiv.divisor + i_topk_1];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,14 +14,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// template <typename Problem_, typename Policy_ = MoeSortingPolicy>
|
||||
// struct MoeSortingPipeline
|
||||
// {
|
||||
// // TODO: this kernel only support warp per row
|
||||
// using Problem = remove_cvref_t<Problem_>;
|
||||
// using Policy = remove_cvref_t<Policy_>;
|
||||
// using WeightType = typename Problem::WeightType;
|
||||
|
||||
// template <typename TopkIdWindow, typename WeightWindow>
|
||||
// CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
|
||||
// const WeightWindow& weight_window,
|
||||
|
||||
@@ -36,9 +36,6 @@ struct BlockGemmARegBSmemCRegOneWarpV1
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
@@ -19,30 +19,7 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
|
||||
|
||||
constexpr index_t NumWarp = kBlockSize / get_warp_size();
|
||||
|
||||
// FIXME
|
||||
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
|
||||
kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
|
||||
}
|
||||
#else
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
|
||||
@@ -21,13 +21,18 @@ struct BlockGemmARegBSmemCRegV2
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr bool kSupportsPartialK = true;
|
||||
static constexpr bool kSupportsPartialN = true;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
template <bool UsePartialN,
|
||||
typename CBlockTensor,
|
||||
typename ABlockTensorTmp,
|
||||
typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void Impl(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp,
|
||||
[[maybe_unused]] const index_t valid_n_iters) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
@@ -134,10 +139,7 @@ struct BlockGemmARegBSmemCRegV2
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
|
||||
constexpr auto kIter = number<kn[number<0>{}]>{};
|
||||
constexpr auto nIter = number<kn[number<1>{}]>{};
|
||||
auto run_n_iter = [&](auto kIter, auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
@@ -166,7 +168,44 @@ struct BlockGemmARegBSmemCRegV2
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
// hot loop:
|
||||
if constexpr(UsePartialN)
|
||||
{
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if(static_cast<index_t>(nIter.value) < valid_n_iters)
|
||||
{
|
||||
run_n_iter(kIter, nIter);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { run_n_iter(kIter, nIter); });
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// C += A * B (executing only the first valid_n_iters N sub-iterations)
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp,
|
||||
const index_t valid_n_iters) const
|
||||
{
|
||||
Impl<true>(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, valid_n_iters);
|
||||
}
|
||||
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
Impl<false>(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, 0);
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
|
||||
@@ -227,7 +266,17 @@ struct BlockGemmARegBSmemCRegV2
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
// C = A * B (executing only the first valid_n_iters N sub-iterations)
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp,
|
||||
const index_t valid_n_iters) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, valid_n_iters);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
|
||||
@@ -16,30 +16,7 @@ struct BlockGemmARegBSmemCRegV2DefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
|
||||
#if 0
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
|
||||
|
||||
constexpr index_t NumWarp = kBlockSize / get_warp_size();
|
||||
|
||||
// FIXME
|
||||
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
|
||||
kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
|
||||
}
|
||||
#else
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -19,30 +19,7 @@ struct BlockGemmASmemBRegCRegV1DefaultPolicy
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
|
||||
|
||||
constexpr index_t NumWarp = kBlockSize / get_warp_size();
|
||||
|
||||
// FIXME
|
||||
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
|
||||
kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
|
||||
}
|
||||
#else
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
|
||||
@@ -5,52 +5,9 @@
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
|
||||
* elements into lower part of a_vec to half its effective size.
|
||||
* @param a_vec Vector to be compressed.
|
||||
* @tparam ADataType The data type of a_vec
|
||||
* @tparam CompressedSize The target compression size
|
||||
* @tparam AVec The vector type of a_vec (deduced)
|
||||
* @return Packed 32‑bit word containing **CompressedSize** 2‑bit fields.
|
||||
* Each field encodes the original position (0–3) of the corresponding
|
||||
* non‑zero element in the input. If fewer than CompressedSize
|
||||
* non‑zeros are found, remaining fields default to 2 (see below).
|
||||
*/
|
||||
template <typename ADataType, index_t CompressedSize, typename AVec>
|
||||
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
|
||||
{
|
||||
// idx holds one 2‑bit index per output element (total CompressedSize entries).
|
||||
// It is initialized to the pattern 0b10 for every field. This matches
|
||||
// what the hardware expects when there are fewer than two non‑zero values
|
||||
// in a 4‑element group – the unused output is treated as coming from slot 2.
|
||||
// The loop below will clear and set each field as real non‑zeros are seen.
|
||||
int32_t idx = 0;
|
||||
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2 << (2 * k)); });
|
||||
|
||||
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
|
||||
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
|
||||
int32_t non_zero_pos = 0;
|
||||
|
||||
static_for<0, 3, 1>{}([&](auto j) {
|
||||
if(a_vec[i * 4 + j] != 0.0f)
|
||||
{
|
||||
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
|
||||
// clear the two‑bit field for this output and insert j
|
||||
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
|
||||
idx |= j << 2 * (i * 2 + non_zero_pos);
|
||||
++non_zero_pos;
|
||||
}
|
||||
});
|
||||
a_vec[i * 2] = nonzero_elems[0];
|
||||
a_vec[i * 2 + 1] = nonzero_elems[1];
|
||||
});
|
||||
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <typename WarpGemmAttribute_>
|
||||
struct WarpGemmSmfmacImpl
|
||||
{
|
||||
@@ -86,10 +43,37 @@ struct WarpGemmSmfmacImpl
|
||||
return WarpGemmAttribute_::get_num_of_access();
|
||||
}
|
||||
|
||||
template <index_t CompressedSize, typename AVec>
|
||||
CK_TILE_DEVICE int32_t compress_a_vec(AVec& a_vec)
|
||||
//----------------------------------------------------------------------------------------------
|
||||
/// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
|
||||
/// elements into lower part of a_vec to half its effective size.
|
||||
///
|
||||
/// @param a_vec Vector to be compressed.
|
||||
///
|
||||
/// @return Four 2-bit indexes of non-zero elements locations
|
||||
///
|
||||
template <typename AVec>
|
||||
CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const
|
||||
{
|
||||
return compress_a_impl<ADataType, CompressedSize>(a_vec);
|
||||
int32_t idx = 0b11101110;
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
|
||||
int32_t non_zero_pos = 0;
|
||||
|
||||
static_for<0, 3, 1>{}([&](auto j) {
|
||||
if(a_vec[i * 4 + j] != 0.0f)
|
||||
{
|
||||
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
|
||||
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
|
||||
idx |= j << 2 * (i * 2 + non_zero_pos);
|
||||
++non_zero_pos;
|
||||
}
|
||||
});
|
||||
a_vec[i * 2] = nonzero_elems[0];
|
||||
a_vec[i * 2 + 1] = nonzero_elems[1];
|
||||
});
|
||||
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
|
||||
@@ -102,11 +86,10 @@ struct WarpGemmSmfmacImpl
|
||||
constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
|
||||
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
static constexpr index_t CompressedSize =
|
||||
ATensor::get_thread_buffer_size() / CompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<ADataType, CompressedSize>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
using AVecCompressed =
|
||||
ext_vector_t<ADataType, ATensor::get_thread_buffer_size() / CompressionRatio>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
@@ -114,9 +97,8 @@ struct WarpGemmSmfmacImpl
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
|
||||
|
||||
const int32_t idx = compress_a_vec<CompressedSize>(a_vec);
|
||||
const int32_t idx = compress_a(a_vec);
|
||||
|
||||
static_assert(CompressedSize == 4);
|
||||
// @TODO can we simply set a_vec_pruned to a_vec[0:3]?
|
||||
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};
|
||||
|
||||
|
||||
@@ -120,10 +120,6 @@ struct BlockNormReduceSync
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
|
||||
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
|
||||
// const auto rs_idx =
|
||||
// mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
|
||||
|
||||
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
|
||||
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
|
||||
|
||||
@@ -360,17 +356,6 @@ struct BlockNormReduceCrossWarpSync
|
||||
template <typename BlockShape>
|
||||
CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_size)
|
||||
{
|
||||
#if 0
|
||||
using S = BlockShape;
|
||||
index_t LastloopN = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N;
|
||||
constexpr index_t NThread = S::WarpPerBlock_N * S::ThreadPerWarp_N;
|
||||
index_t iNLane = get_thread_id() % NThread;
|
||||
index_t iN0 = LastloopN / (S::Vector_N * S::ThreadPerWarp_N);
|
||||
index_t iN1 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) / S::Vector_N;
|
||||
index_t N2 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) % S::Vector_N;
|
||||
index_t iN3 = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0;
|
||||
return iN0 * S::Vector_N + iN3;
|
||||
#endif
|
||||
using S_ = BlockShape;
|
||||
constexpr index_t ThreadsPerBlock_N = S_::WarpPerBlock_N * S_::ThreadPerWarp_N;
|
||||
|
||||
|
||||
@@ -140,28 +140,6 @@ struct BlockReduce2d
|
||||
ReducePacksPerXDim{});
|
||||
}
|
||||
|
||||
#if 0
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto spans = XDistributedTensor_::get_distributed_spans();
|
||||
|
||||
// FIXME: hard coded to reduce 2nd axis
|
||||
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
|
||||
constexpr auto y_dstr_idx = make_tuple(dstr_idx_i0);
|
||||
|
||||
auto y = y_tensor[y_dstr_idx];
|
||||
|
||||
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
|
||||
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
|
||||
const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
|
||||
|
||||
y = reduce_func(y, x);
|
||||
});
|
||||
|
||||
y_tensor(y_dstr_idx) = y;
|
||||
});
|
||||
#endif
|
||||
|
||||
template <typename XDistributedTensor_>
|
||||
CK_TILE_DEVICE static auto MakeYBlockTile()
|
||||
{
|
||||
@@ -240,10 +218,6 @@ struct BlockReduce2dSync
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
|
||||
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
|
||||
// const auto rs_idx =
|
||||
// y_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
|
||||
|
||||
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
// loop over thread data
|
||||
|
||||
45
include/ck_tile/utility/tile_load_store_microkernels.hpp
Normal file
45
include/ck_tile/utility/tile_load_store_microkernels.hpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* @file tile_load_store_microkernels.hpp
|
||||
* @brief Generic tile store/load microkernels.
|
||||
*
|
||||
* Setup::create() must return:
|
||||
* - For StoreTile: tuple<window, tile>
|
||||
* - For LoadTile: window
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Setup>
|
||||
struct StoreTile
|
||||
{
|
||||
static constexpr index_t kBlockSize = Setup::kBlockSize;
|
||||
|
||||
CK_TILE_DEVICE void operator()() const
|
||||
{
|
||||
auto [window, tile] = Setup::create();
|
||||
store_tile(window, tile);
|
||||
block_sync_lds();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Setup>
|
||||
struct LoadTile
|
||||
{
|
||||
static constexpr index_t kBlockSize = Setup::kBlockSize;
|
||||
|
||||
CK_TILE_DEVICE void operator()() const
|
||||
{
|
||||
auto window = Setup::create();
|
||||
[[maybe_unused]] volatile auto tile = load_tile(window);
|
||||
block_sync_lds();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user