[CK Tile] Adding WMMA wrappers for dense builtins (#5801)

## Motivation

This PR is part of the [WMMA/MFMA] unification work. It is the first in
a series of PRs that add all the necessary MMA builtins as
`amdgcn_mma` structs.

## Technical Details

This change adds new specializations for WMMA dense builtins. In total,
we now have 9 RDNA4 builtins and 3 RDNA3 builtins.

## Test Plan

All the new wrappers were added to the test suite in
`test_amdgcn_mma_layout.inc`.

## Test Result

Tests pass locally; waiting for CI.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Yung-sheng Tu <yung-sheng@streamhpc.com>
This commit is contained in:
Wojciech Laskowski
2026-04-27 13:57:51 +02:00
committed by GitHub
parent 26ff0da492
commit a581a451f1
9 changed files with 296 additions and 76 deletions

View File

@@ -77,8 +77,8 @@ using Intrinsics = ck_tile::tuple<
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, Target950, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target12, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32_gfx12
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32_gfx12
>;
// clang-format on

View File

@@ -79,8 +79,4 @@ concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) {
};
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
struct DefaultSparseWmmaCtrlFlags
{
};
} // namespace ck_tile::core::arch::mma

View File

@@ -7,6 +7,7 @@
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma_traits.hpp"
namespace ck_tile::core::arch::mma {
@@ -41,7 +42,7 @@ struct SparseWmmaDefaultSelector
WaveTileM,
WaveTileN,
WaveTileKTest,
DefaultSparseWmmaCtrlFlags,
DefaultWmmaCtrlFlags,
CompilerTarget,
MmaOpFamily::SPARSE>;

View File

@@ -3,29 +3,6 @@
#pragma once
namespace ck_tile::core::arch::mma {
/**
* @enum WmmaCtrlFlags
* @brief Common wmma control flags for gfx11 and gfx12
*/
enum struct WmmaCtrlFlags : bool
{
// Only has an effect on gfx11 when the accumulator is 16-bit
// Determines which half of the 32-bit accum register to use
// Low = bits [15:0]
// High = bits[31:16]
LOW = false,
HIGH = true,
// Only has an effect on gfx11 / 12 when the input is 8-bit int
// Signage indicator of inputs / accum
UNSIGNED = false,
SIGNED = true
};
} // namespace ck_tile::core::arch::mma
// Include the architecture-specific WMMA implementations and traits
#include "wmma_gfx11.hpp"
#include "wmma_gfx12.hpp"

View File

@@ -8,8 +8,8 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
namespace ck_tile::core::arch::mma {
// TODO: Specifically for gfx11 wmma, we need to deal with quirks such as:
@@ -36,30 +36,6 @@ namespace ck_tile::core::arch::mma {
// For flexibility, it is recommended that for each backend wrapper it supports at least
// one packed register for each input to be able to process smaller K values by padding.
/**
* @class DefaultWmmaFlags
* @brief Generates default WMMA control flags based on data types.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
*/
template <typename ADataType, typename BDataType, typename CDataType>
struct DefaultWmmaCtrlFlags
{
// Generate default flags for signage
// Only used currently for integer inputs / accum in gfx11 / gfx12
constexpr static WmmaCtrlFlags InputSignA =
std::is_signed_v<ADataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
constexpr static WmmaCtrlFlags InputSignB =
std::is_signed_v<BDataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
constexpr static WmmaCtrlFlags AccumSign =
std::is_signed_v<CDataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
// Generate default flags for accumulator destination bits.
// Only used if accumulation size is 16-bit in gfx11
constexpr static WmmaCtrlFlags AccumBits = WmmaCtrlFlags::LOW;
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX11
@@ -76,11 +52,62 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aVec, bVec, cVec)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX11
* architecture.
* @details Matches the 16u x 16u x 16u DENSE shape and is enabled only when
* is_target_family_gfx11<CompilerTarget>() is true (std::enable_if_t guard below).
* CtrlFlags is part of the specialization signature but is not read by exec() here.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx11<CompilerTarget>()>>
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
/**
* @brief Issues a single __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32 call.
* @param aVec A operand, forwarded to the builtin unchanged
* @param bVec B operand, forwarded to the builtin unchanged
* @param cVec C operand, forwarded to the builtin unchanged
* @return The builtin's result, brace-initialized into a CVecType
*/
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
return {__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(aVec, bVec, cVec)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX11
* architecture.
* @details Matches the 16u x 16u x 16u DENSE shape and is enabled only when
* is_target_family_gfx11<CompilerTarget>() is true (std::enable_if_t guard below).
* exec() reads CtrlFlags::Clamp, so CtrlFlags must expose a `Clamp` member.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx11<CompilerTarget>()>>
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
/**
* @brief Issues a single __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32 call.
* @param aVec A operand; reinterpreted as int32x4_t via bit_cast before the call
* @param bVec B operand; reinterpreted as int32x4_t via bit_cast before the call
* @param cVec C operand, forwarded to the builtin unchanged
* @return The builtin's result, brace-initialized into a CVecType
* @note Both signedness arguments are hard-coded to `true`; CtrlFlags::Clamp is
* passed as the builtin's final argument.
*/
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
return {__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, // A signedness
bit_cast<int32x4_t>(aVec),
true, // B signedness
bit_cast<int32x4_t>(bVec),
cVec,
CtrlFlags::Clamp)};
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -8,8 +8,8 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
namespace ck_tile::core::arch::mma {
@@ -22,7 +22,7 @@ namespace ck_tile::core::arch::mma {
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_wmma for fp16_t, fp16_t, fp32_t MMA operation on GFX12
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
@@ -32,15 +32,208 @@ namespace ck_tile::core::arch::mma {
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(aVec, bVec, cVec)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX12
* architecture.
* @details Matches the 16u x 16u x 16u DENSE shape and is enabled only when
* is_target_family_gfx12<CompilerTarget>() is true (std::enable_if_t guard below).
* CtrlFlags is part of the specialization signature but is not read by exec() here.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
/**
* @brief Issues a single __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12 call.
* @param aVec A operand, forwarded to the builtin unchanged
* @param bVec B operand, forwarded to the builtin unchanged
* @param cVec C operand, forwarded to the builtin unchanged
* @return The builtin's result, brace-initialized into a CVecType
*/
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
return {__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(aVec, bVec, cVec)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp16_t MMA operation on GFX12
* architecture.
* @details Matches the 16u x 16u x 16u DENSE shape and is enabled only when
* is_target_family_gfx12<CompilerTarget>() is true (std::enable_if_t guard below).
* CtrlFlags is part of the specialization signature but is not read by exec() here.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
: amdgcn_mma_base<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
/**
* @brief Issues a single __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12 call.
* @param aVec A operand, forwarded to the builtin unchanged
* @param bVec B operand, forwarded to the builtin unchanged
* @param cVec C operand, forwarded to the builtin unchanged
* @return The builtin's result, brace-initialized into a CVecType
*/
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
return {__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(aVec, bVec, cVec)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, bf16_t MMA operation on GFX12
* architecture.
* @details Matches the 16u x 16u x 16u DENSE shape and is enabled only when
* is_target_family_gfx12<CompilerTarget>() is true (std::enable_if_t guard below).
* CtrlFlags is part of the specialization signature but is not read by exec() here.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
: amdgcn_mma_base<bf16_t, bf16_t, bf16_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
/**
* @brief Issues a single __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12 call.
* @param aVec A operand, forwarded to the builtin unchanged
* @param bVec B operand, forwarded to the builtin unchanged
* @param cVec C operand, forwarded to the builtin unchanged
* @return The builtin's result, brace-initialized into a CVecType
*/
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
return {__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12(aVec, bVec, cVec)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX12
* architecture.
* @details Matches the 16u x 16u x 16u DENSE shape and is enabled only when
* is_target_family_gfx12<CompilerTarget>() is true (std::enable_if_t guard below).
* exec() reads CtrlFlags::Clamp, so CtrlFlags must expose a `Clamp` member.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
/**
* @brief Issues a single __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12 call.
* @param aVec A operand; reinterpreted as int32x2_t via bit_cast before the call
* @param bVec B operand; reinterpreted as int32x2_t via bit_cast before the call
* @param cVec C operand, forwarded to the builtin unchanged
* @return The builtin's result, brace-initialized into a CVecType
* @note Both signedness arguments are hard-coded to `true`; CtrlFlags::Clamp is
* passed as the builtin's final argument.
*/
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
return {__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // A signedness
bit_cast<int32x2_t>(aVec),
true, // B signedness
bit_cast<int32x2_t>(bVec),
cVec,
CtrlFlags::Clamp)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX12
* architecture.
* @details Matches the 16u x 16u x 16u DENSE shape and is enabled only when
* is_target_family_gfx12<CompilerTarget>() is true (std::enable_if_t guard below).
* CtrlFlags is part of the specialization signature but is not read by exec() here.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
/**
* @brief Issues a single __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12 call.
* @param aVec A operand; reinterpreted as int32x2_t via bit_cast before the call
* @param bVec B operand; reinterpreted as int32x2_t via bit_cast before the call
* @param cVec C operand, forwarded to the builtin unchanged
* @return The builtin's result, brace-initialized into a CVecType
*/
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
return {__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(
bit_cast<int32x2_t>(aVec), bit_cast<int32x2_t>(bVec), cVec)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX12
* architecture.
* @details Matches the 16u x 16u x 16u DENSE shape and is enabled only when
* is_target_family_gfx12<CompilerTarget>() is true (std::enable_if_t guard below).
* CtrlFlags is part of the specialization signature but is not read by exec() here.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
/**
* @brief Issues a single __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12 call.
* @param aVec A operand (fp8_t); reinterpreted as int32x2_t via bit_cast before the call
* @param bVec B operand (bf8_t); reinterpreted as int32x2_t via bit_cast before the call
* @param cVec C operand, forwarded to the builtin unchanged
* @return The builtin's result, brace-initialized into a CVecType
*/
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
return {__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12(
bit_cast<int32x2_t>(aVec), bit_cast<int32x2_t>(bVec), cVec)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX12
* architecture.
* @details Matches the 16u x 16u x 16u DENSE shape and is enabled only when
* is_target_family_gfx12<CompilerTarget>() is true (std::enable_if_t guard below).
* CtrlFlags is part of the specialization signature but is not read by exec() here.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
/**
* @brief Issues a single __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12 call.
* @param aVec A operand (bf8_t); reinterpreted as int32x2_t via bit_cast before the call
* @param bVec B operand (fp8_t); reinterpreted as int32x2_t via bit_cast before the call
* @param cVec C operand, forwarded to the builtin unchanged
* @return The builtin's result, brace-initialized into a CVecType
*/
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
return {__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12(
bit_cast<int32x2_t>(aVec), bit_cast<int32x2_t>(bVec), cVec)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX12
* architecture.
* @details Matches the 16u x 16u x 16u DENSE shape and is enabled only when
* is_target_family_gfx12<CompilerTarget>() is true (std::enable_if_t guard below).
* CtrlFlags is part of the specialization signature but is not read by exec() here.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx12<CompilerTarget>()>>
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
/**
* @brief Issues a single __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12 call.
* @param aVec A operand; reinterpreted as int32x2_t via bit_cast before the call
* @param bVec B operand; reinterpreted as int32x2_t via bit_cast before the call
* @param cVec C operand, forwarded to the builtin unchanged
* @return The builtin's result, brace-initialized into a CVecType
*/
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
return {__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12(
bit_cast<int32x2_t>(aVec), bit_cast<int32x2_t>(bVec), cVec)};
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -36,7 +36,7 @@ struct WmmaDefaultSelector
{
private:
// By default, let's assume no special flags for WMMA
using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
using CtrlFlags = DefaultWmmaCtrlFlags;
// Define our candidate WMMA implementation for the current parameters
using CandidateOp = amdgcn_mma<ADataType,
@@ -93,7 +93,7 @@ struct WmmaDefaultSelector<ADataType,
CompilerTarget>
{
// By default, let's assume no special flags for WMMA
using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
using CtrlFlags = DefaultWmmaCtrlFlags;
// Default unsupported pass-through if no instruction is found
using SelectedOp = amdgcn_mma<ADataType,

View File

@@ -41,4 +41,18 @@ struct is_mma_op_wmma<MmaOp, std::enable_if_t<std::is_same_v<typename MmaOp::OpT
template <typename MmaOp>
static constexpr bool is_mma_op_wmma_v = is_mma_op_wmma<MmaOp>::value;
/**
 * @struct DefaultWmmaCtrlFlags
 * @brief Default WMMA control flags for dense and sparse WMMA operations.
 *
 * Every flag defaults to `false`, i.e. the "no special behaviour" setting:
 * no clamping, and the low half of the accumulator register for 16-bit results.
 */
struct DefaultWmmaCtrlFlags
{
    // Read by the integer (iu8) WMMA wrappers, which pass it as the builtin's
    // final argument. Must be a compile-time constant.
    constexpr static bool Clamp = false;

    // Only has an effect on gfx11 when the accumulator is 16-bit.
    // Determines which half of the 32-bit accum register to use for the 16-bit result.
    // false = low bits [15:0], true = high bits [31:16]
    // The default selects the low half.
    constexpr static bool UseHighAccumBits = false;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -3,19 +3,16 @@
#pragma once
#include <hip/hip_runtime.h>
#include <gtest/gtest.h>
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/scale/scale.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integer.hpp"
// #include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/host/device_memory.hpp"
@@ -23,9 +20,9 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/stream_config.hpp"
#include <gtest/gtest.h>
#include <hip/hip_runtime.h>
#include <cstdint>
#include <utility>
#include <vector>
#include <cmath>
#include <cstdint>
#include <vector>
@@ -40,7 +37,12 @@ using namespace mma;
using F8 = fp8_t;
using BF8 = bf8_t;
using F16 = fp16_t;
using BF16 = bf16_t;
using F32 = fp32_t;
using I8 = int8_t;
using FP8 = fp8_t;
using BF8 = bf8_t;
using I32 = int32_t;
using Target908 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX908>());
using Target942 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX942>());
using Target950 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>());
@@ -255,11 +257,21 @@ using Gfx950Intrinsics = ::testing::Types<
// amdgcn_mma<F4, F4, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, Target950, MmaOpFamily::SCALE> // mfma_scale_f32_32x32x64_f8f6f4
>;
using Gfx11Intrinsics = ::testing::Types<
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target11, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32,
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf16_w32,
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE> // wmma_i32_16x16x16_iu8_w32
>;
using Gfx12Intrinsics = ::testing::Types<
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target12, MmaOpFamily::SPARSE> // swmmac_f32_16x16x32_f16_w32
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12,
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf16_w32_gfx12,
amdgcn_mma<F16, F16, F16, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f16_16x16x16_f16_w32_gfx12,
amdgcn_mma<BF16, BF16, BF16, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_bf16_16x16x16_bf16_w32_gfx12,
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu8_w32_gfx12,
amdgcn_mma<FP8, FP8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_fp8_fp8_w32_gfx12,
amdgcn_mma<FP8, BF8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_fp8_bf8_w32_gfx12,
amdgcn_mma<BF8, FP8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf8_fp8_w32_gfx12,
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf8_bf8_w32_gfx12
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target12, MmaOpFamily::SPARSE> // swmmac_f32_16x16x32_f16_w32
>;
// clang-format on