mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK Tile] Adding WMMA wrappers for dense builtins (#5801)
## Motivation This PR is part of the [WMMA/MFMA] unification work. It's the first of the series of PRs that add all the necessary MMA builtins as a `amdgcn_mma` structs. ## Technical Details This change adds new specializations for WMMA dense builtins. In total, we have now 9 RDNA4 builtins and 3 RDNA3 builtins. ## Test Plan All the new wrappers were added to the test suite in `test_amdgcn_mma_layout.inc`. ## Test Result Test pass locally, waiting for the CI. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Yung-sheng Tu <yung-sheng@streamhpc.com>
This commit is contained in:
committed by
GitHub
parent
ffba5aefcc
commit
0ebeb88ba9
@@ -77,8 +77,8 @@ using Intrinsics = ck_tile::tuple<
|
||||
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, Target950, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target12, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32_gfx12
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32_gfx12
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -79,8 +79,4 @@ concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) {
|
||||
};
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
struct DefaultSparseWmmaCtrlFlags
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
@@ -41,7 +42,7 @@ struct SparseWmmaDefaultSelector
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest,
|
||||
DefaultSparseWmmaCtrlFlags,
|
||||
DefaultWmmaCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE>;
|
||||
|
||||
|
||||
@@ -3,29 +3,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @enum WmmaCtrlFlags
|
||||
* @brief Common wmma control flags for gfx11 and gfx12
|
||||
*/
|
||||
enum struct WmmaCtrlFlags : bool
|
||||
{
|
||||
// Only has an effect on gfx11 when the accumulator is 16-bit
|
||||
// Determines which half of the 32-bit accum register to use
|
||||
// Low = bits [15:0]
|
||||
// High = bits[31:16]
|
||||
LOW = false,
|
||||
HIGH = true,
|
||||
|
||||
// Only has an effect on gfx11 / 12 when the input is 8-bit int
|
||||
// Signage indicator of inputs / accum
|
||||
UNSIGNED = false,
|
||||
SIGNED = true
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
// Include the architecture-specific WMMA implementations and traits
|
||||
#include "wmma_gfx11.hpp"
|
||||
#include "wmma_gfx12.hpp"
|
||||
|
||||
@@ -8,8 +8,8 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
// TODO: Specifically for gfx11 wmma, we need to deal with quirks such as:
|
||||
@@ -36,30 +36,6 @@ namespace ck_tile::core::arch::mma {
|
||||
// For flexibility, it is recommended that for each backend wrapper it supports at least
|
||||
// one packed register for each input to be able to process smaller K values by padding.
|
||||
|
||||
/**
|
||||
* @class DefaultWmmaFlags
|
||||
* @brief Generates default WMMA control flags based on data types.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
*/
|
||||
template <typename ADataType, typename BDataType, typename CDataType>
|
||||
struct DefaultWmmaCtrlFlags
|
||||
{
|
||||
// Generate default flags for signage
|
||||
// Only used currently for integer inputs / accum in gfx11 / gfx12
|
||||
constexpr static WmmaCtrlFlags InputSignA =
|
||||
std::is_signed_v<ADataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
|
||||
constexpr static WmmaCtrlFlags InputSignB =
|
||||
std::is_signed_v<BDataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
|
||||
constexpr static WmmaCtrlFlags AccumSign =
|
||||
std::is_signed_v<CDataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
|
||||
|
||||
// Generate default flags for accumulator destination bits.
|
||||
// Only used if accumulation size is 16-bit in gfx11
|
||||
constexpr static WmmaCtrlFlags AccumBits = WmmaCtrlFlags::LOW;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX11
|
||||
@@ -76,11 +52,62 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aVec, bVec, cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX11
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx11<CompilerTarget>()>>
|
||||
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(aVec, bVec, cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX11
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx11<CompilerTarget>()>>
|
||||
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, // A signedness
|
||||
bit_cast<int32x4_t>(aVec),
|
||||
true, // B signedness
|
||||
bit_cast<int32x4_t>(bVec),
|
||||
cVec,
|
||||
CtrlFlags::Clamp)};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
@@ -8,8 +8,8 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
@@ -22,7 +22,7 @@ namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_wmma for fp16_t, fp16_t, fp32_t MMA operation on GFX12
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
@@ -32,15 +32,208 @@ namespace ck_tile::core::arch::mma {
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(aVec, bVec, cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
|
||||
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(aVec, bVec, cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp16_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(aVec, bVec, cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, bf16_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
|
||||
: amdgcn_mma_base<bf16_t, bf16_t, bf16_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12(aVec, bVec, cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
|
||||
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // A signedness
|
||||
bit_cast<int32x2_t>(aVec),
|
||||
true, // B signedness
|
||||
bit_cast<int32x2_t>(bVec),
|
||||
cVec,
|
||||
CtrlFlags::Clamp)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
|
||||
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(
|
||||
bit_cast<int32x2_t>(aVec), bit_cast<int32x2_t>(bVec), cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
|
||||
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12(
|
||||
bit_cast<int32x2_t>(aVec), bit_cast<int32x2_t>(bVec), cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
|
||||
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12(
|
||||
bit_cast<int32x2_t>(aVec), bit_cast<int32x2_t>(bVec), cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
|
||||
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12(
|
||||
bit_cast<int32x2_t>(aVec), bit_cast<int32x2_t>(bVec), cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
@@ -36,7 +36,7 @@ struct WmmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
// By default, let's assume no special flags for WMMA
|
||||
using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
|
||||
using CtrlFlags = DefaultWmmaCtrlFlags;
|
||||
|
||||
// Define our candidate WMMA implementation for the current parameters
|
||||
using CandidateOp = amdgcn_mma<ADataType,
|
||||
@@ -93,7 +93,7 @@ struct WmmaDefaultSelector<ADataType,
|
||||
CompilerTarget>
|
||||
{
|
||||
// By default, let's assume no special flags for WMMA
|
||||
using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
|
||||
using CtrlFlags = DefaultWmmaCtrlFlags;
|
||||
|
||||
// Default unsupported pass-through if no instruction is found
|
||||
using SelectedOp = amdgcn_mma<ADataType,
|
||||
|
||||
@@ -41,4 +41,18 @@ struct is_mma_op_wmma<MmaOp, std::enable_if_t<std::is_same_v<typename MmaOp::OpT
|
||||
template <typename MmaOp>
|
||||
static constexpr bool is_mma_op_wmma_v = is_mma_op_wmma<MmaOp>::value;
|
||||
|
||||
/**
|
||||
* @struct DefaultWmmaCtrlFlags
|
||||
* @brief Default WMMA control flags for dense and sparse WMMA operations.
|
||||
*/
|
||||
struct DefaultWmmaCtrlFlags
|
||||
{
|
||||
constexpr static bool Clamp = false;
|
||||
|
||||
// Only has an effect on gfx11 when the accumulator is 16-bit.
|
||||
// Determines which half of the 32-bit accum register to use for the 16-bit result.
|
||||
// false = low bits [15:0], true = high bits [31:16]
|
||||
constexpr static bool UseHighAccumBits = true;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
@@ -3,19 +3,16 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
// #include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
@@ -23,9 +20,9 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
@@ -40,7 +37,12 @@ using namespace mma;
|
||||
using F8 = fp8_t;
|
||||
using BF8 = bf8_t;
|
||||
using F16 = fp16_t;
|
||||
using BF16 = bf16_t;
|
||||
using F32 = fp32_t;
|
||||
using I8 = int8_t;
|
||||
using FP8 = fp8_t;
|
||||
using BF8 = bf8_t;
|
||||
using I32 = int32_t;
|
||||
using Target908 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX908>());
|
||||
using Target942 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX942>());
|
||||
using Target950 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>());
|
||||
@@ -255,11 +257,21 @@ using Gfx950Intrinsics = ::testing::Types<
|
||||
// amdgcn_mma<F4, F4, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, Target950, MmaOpFamily::SCALE> // mfma_scale_f32_32x32x64_f8f6f4
|
||||
>;
|
||||
using Gfx11Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target11, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32,
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf16_w32,
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE> // wmma_i32_16x16x16_iu8_w32
|
||||
>;
|
||||
using Gfx12Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target12, MmaOpFamily::SPARSE> // swmmac_f32_16x16x32_f16_w32
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12,
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf16_w32_gfx12,
|
||||
amdgcn_mma<F16, F16, F16, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f16_16x16x16_f16_w32_gfx12,
|
||||
amdgcn_mma<BF16, BF16, BF16, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_bf16_16x16x16_bf16_w32_gfx12,
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu8_w32_gfx12,
|
||||
amdgcn_mma<FP8, FP8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_fp8_fp8_w32_gfx12,
|
||||
amdgcn_mma<FP8, BF8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_fp8_bf8_w32_gfx12,
|
||||
amdgcn_mma<BF8, FP8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf8_fp8_w32_gfx12,
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf8_bf8_w32_gfx12
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target12, MmaOpFamily::SPARSE> // swmmac_f32_16x16x32_f16_w32
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user