[rocm-libraries] ROCm/rocm-libraries#8227 (commit 75c30d5)

=?UTF-8?q?[CK=20TILE]=20Unification=20Work=20=E2=80=93=20?=
 =?UTF-8?q?Remove=20unification=20Flag=20structs=20in=20favor=20of=20new?=
 =?UTF-8?q?=20WarpGemmParams=20(#8227)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

Recently, the way flags are sent down to the intrinsics was changed in
CK Tile. At the point where the WarpGemm is invoked, an arbitrary number
of template parameters can be passed, and these are passed down all the
way to the lowest level intrinsics wrappers. Here
`WarpGemmParamsParser<>` is used to extract flags for the intrinsics.

In this MR we adapt the the unification framework (amdgcn_mma struct and
MmaPipelines) to work in the same way. By doing this, there is no longer
a point in our custom intrinsic Flag structs, so these are removed.

Unrelated but I also tried removing the MmaPipeline flags because they
arn't used for anything except CTranspose, which is already available.
This also make test_amdgcn_mma_pipeline completely redundant so removed
that as well.
This commit is contained in:
Kiefer van Teutem
2026-06-26 12:00:58 +00:00
committed by assistant-librarian[bot]
parent 621697af8c
commit 2089713f94
33 changed files with 1059 additions and 1643 deletions

View File

@@ -73,17 +73,17 @@ int check_tile_distr_enc()
// List of intrinsics to test.
// clang-format off
using Intrinsics = ck_tile::tuple<
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, Target950, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu4_w32
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu4_w32_gfx12
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE> // wmma_i32_16x16x32_iu4_w32_gfx12
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, Target908, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, Target950, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, Target11, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu4_w32
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu4_w32_gfx12
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, Target12, MmaOpFamily::DENSE> // wmma_i32_16x16x32_iu4_w32_gfx12
>;
// clang-format on

View File

@@ -43,7 +43,6 @@
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp"
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"

View File

@@ -246,25 +246,14 @@ CK_TILE_HOST_DEVICE constexpr const char* to_string(Unsupported) { return "Unsup
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
/**
* @concept HasExecSignature
* @brief Helper concept for exec signature check.
*/
template <typename MmaOp, typename... ExecArgs>
concept HasExecSignature = requires {
{
MmaOp::exec(typename MmaOp::AVecType{},
typename MmaOp::BVecType{},
typename MmaOp::CVecType{},
std::declval<ExecArgs>()...)
} -> std::convertible_to<typename MmaOp::CVecType>;
};
/**
* @concept MmaOpI
* @brief Expresses the meta-data interface required for each MmaOp policy.
*/
// TODO: Make sure this actually matches amdgcn_mma.
// NOTE: It is no longer possible to perform a check on the exec() function, since it is now
// templated over the variadic WarpGemmParams template pack for intrinsic flags. It seems like
// concepts do not work for templated device functions.
template <typename MmaOp>
concept MmaOpI = requires(MmaOp op) {
// Requires an op context
@@ -287,7 +276,7 @@ concept MmaOpI = requires(MmaOp op) {
{ MmaOp::kCMPerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMNumAccess } -> std::convertible_to<unsigned int>;
{ MmaOp::kCompressionRatio } -> std::convertible_to<unsigned int>;
} && (HasExecSignature<MmaOp> || HasExecSignature<MmaOp, int> || HasExecSignature<MmaOp, int, int>);
};
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
@@ -305,7 +294,6 @@ concept MmaOpI = requires(MmaOp op) {
* @tparam FragM M-dimension of mma intrinsic (MmaTile)
* @tparam FragN N-dimension of mma intrinsic (MmaTile)
* @tparam FragK K-dimension of mma intrinsic (MmaTile)
* @tparam CtrlFlags Control flags for mma operation
* @tparam CompilerTarget The current compiler target
* @tparam OpFamily_ The type of operation (dense, sparse, scale, etc.)
* @tparam Enabler SFINAE enabler
@@ -316,7 +304,6 @@ template <typename ADataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
typename CtrlFlags,
typename CompilerTarget,
MmaOpFamily OpFamily_,
typename Enabler = void>
@@ -326,6 +313,7 @@ struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1
// clang-format on
{
// This is a default pass-through implementation that doesn't do anything practical.
template <typename... Params>
CK_TILE_DEVICE static auto
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
{
@@ -347,7 +335,6 @@ template <typename ADataType,
std::uint32_t FragM,
std::uint32_t FragN,
std::uint32_t FragK,
typename CtrlFlags,
typename CompilerTarget,
MmaOpFamily OpFamily_,
typename Enabler = void>
@@ -357,7 +344,6 @@ CK_TILE_HOST_DEVICE void print(amdgcn_mma<ADataType,
FragM,
FragN,
FragK,
CtrlFlags,
CompilerTarget,
OpFamily_,
Enabler> const& mmaObj)
@@ -392,10 +378,6 @@ CK_TILE_HOST_DEVICE void print(amdgcn_mma<ADataType,
printf(" kCNBlocks : %d\n", mmaObj.kCNBlocks);
printf(" CBlockDimInVecDim : %d\n", mmaObj.CBlockDimInVecDim);
printf("Instruction name : %s\n", ObjType::instruction_name);
if constexpr(!std::is_same_v<CtrlFlags, void>)
{
print_flags(CtrlFlags{});
}
print(CompilerTarget{});
}

File diff suppressed because it is too large Load Diff

View File

@@ -53,7 +53,6 @@ struct MmaDefaultSelector<ADataType,
WaveTileM,
WaveTileN,
WaveTileK,
DefaultMfmaCtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE>::SelectedOp;
};

View File

@@ -3,16 +3,8 @@
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include <cinttypes>
#include <stdio.h>
#include <type_traits>
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
namespace ck_tile::core::arch::mma {
@@ -56,42 +48,4 @@ struct is_mma_op_mfma<MmaOp, std::enable_if_t<std::is_same_v<typename MmaOp::OpT
template <typename MmaOp>
static constexpr bool is_mma_op_mfma_v = is_mma_op_mfma<MmaOp>::value;
/**
* @struct DefaultMfmaCtrlFlags
* @brief Default MFMA flags, no broadcasting or rotation of inputs
* @note For f64 MFMA instructions, CBSZ and ABID are ignored and BLGP is repurposed for matrix
* negation. BLGP bits [0:2] negate the A, B, and C input matrices respectively (ref. ISA docs for
* MI300 Instinct).
*/
struct DefaultMfmaCtrlFlags
{
static constexpr int32_t Cbsz = 0; // CBSZ flag, default 0
static constexpr int32_t Abid = 0; // ABID flag, default 0
static constexpr int32_t Blgp = 0; // BLGP flag, default 0
};
CK_TILE_HOST_DEVICE void print_flags(DefaultMfmaCtrlFlags const& ctrlFlags)
{
printf("CtrlFlags Cbsz / Abid / Blgp : %" PRId32 " / %" PRId32 " / %" PRId32 "\n",
ctrlFlags.Cbsz,
ctrlFlags.Abid,
ctrlFlags.Blgp);
}
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
/**
* @concept CtrlFlagsGfx9I
* @brief Expresses the interface of required members for each CtrlFlags type on Gfx9
*/
template <typename CtrlFlags>
concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) {
// Flag members for Gfx9 MFMA instructions
{ CtrlFlags::Cbsz } -> std::convertible_to<int32_t>;
{ CtrlFlags::Abid } -> std::convertible_to<int32_t>;
{ CtrlFlags::Blgp } -> std::convertible_to<int32_t>;
};
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
} // namespace ck_tile::core::arch::mma

View File

@@ -16,88 +16,11 @@
#endif
namespace ck_tile::core::arch::mma {
/*! @enum MmaPipelineOptionFlag
* @brief Individual option flags for configuring MmaPipeline behavior.
*/
enum struct MmaPipelineOptionFlag : unsigned
{
NONE = 0x0, ///< No flags set
ABSwap = 0x1, ///< Swap A and B inputs to transpose the C output
COMPRESS_A = 0x2, ///< Enable compressed (sparse) A matrix input
};
/**
* @struct MmaPipelineOptionFlags
* @brief Type-safe bitmask wrapper for combining @ref MmaPipelineOptionFlag values.
* @par Provides bitwise OR, AND, NOT, and equality operators for composing
* and querying pipeline option flags.
*/
struct MmaPipelineOptionFlags
{
using Type = std::underlying_type_t<MmaPipelineOptionFlag>;
explicit constexpr MmaPipelineOptionFlags() : mFlags(0) {}
explicit constexpr MmaPipelineOptionFlags(Type value) : mFlags(value) {}
constexpr MmaPipelineOptionFlags(MmaPipelineOptionFlag singleFlag) : mFlags(toType(singleFlag))
{
}
constexpr MmaPipelineOptionFlags(const MmaPipelineOptionFlags& original)
: mFlags(original.mFlags)
{
}
constexpr MmaPipelineOptionFlags& operator|=(MmaPipelineOptionFlag addValue)
{
mFlags |= toType(addValue);
return *this;
}
constexpr MmaPipelineOptionFlags operator|(MmaPipelineOptionFlag addValue) const
{
MmaPipelineOptionFlags result(*this);
result |= addValue;
return result;
}
constexpr MmaPipelineOptionFlags& operator&=(MmaPipelineOptionFlag maskValue)
{
mFlags &= toType(maskValue);
return *this;
}
constexpr MmaPipelineOptionFlags operator&(MmaPipelineOptionFlag maskValue) const
{
MmaPipelineOptionFlags result(*this);
result &= maskValue;
return result;
}
constexpr MmaPipelineOptionFlags operator~() const
{
MmaPipelineOptionFlags result(*this);
result.mFlags = ~result.mFlags;
return result;
}
constexpr bool testFlag(MmaPipelineOptionFlag flag) const
{
return (flag == MmaPipelineOptionFlag::NONE) ? mFlags == toType(flag) : *this & flag;
}
constexpr operator bool() const { return mFlags != toType(MmaPipelineOptionFlag::NONE); }
constexpr bool operator==(Type rhs) const { return mFlags == rhs; }
private:
Type mFlags;
static constexpr Type toType(MmaPipelineOptionFlag f) { return static_cast<Type>(f); }
};
constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOptionFlags& rhs)
{
return rhs == lhs;
}
/**
* @class MmaPipelineBase
* @brief CRTP base class that implements the common Mma pipeline logic shared by
* all concrete pipeline drivers (e.g., dense wave-wise, sparse, etc.).
*
* @tparam Flags_ Compile-time bitmask of @ref MmaPipelineOptionFlag controlling
* pipeline behavior (e.g., C transposition, A compression).
* @tparam Derived The concrete CRTP-derived pipeline class. Must expose:
* - Type aliases: @c AWarpTensor, @c BWarpTensor, @c CWarpTensor, @c MmaOp
* - Transform aliases: @c ATransform, @c BTransform, @c CTransform, @c DTransform
@@ -107,14 +30,11 @@ constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOpt
* 1. Apply pre-transforms to input buffers (A, B, C).
* 2. Delegate to @c Derived::execImpl for the actual mma loop.
* 3. Apply post-transform to output buffer (D).
* When @c ABSwap is set, the A and B inputs are swapped before step 1.
* When CTranspose is used, the A and B inputs are swapped before step 1.
*/
// TODO: c++20: use MmaPipelineOptionFlags directly
template <MmaPipelineOptionFlags::Type Flags_, typename Derived>
template <typename Derived>
struct MmaPipelineBase
{
static constexpr auto Flags = MmaPipelineOptionFlags(Flags_);
/**
* @brief Entry point: execute the full Mma pipeline (transforms + mma loop + output).
* @tparam ATensor Type of the A WaveTile tensor (static_distributed_tensor).
@@ -125,17 +45,17 @@ struct MmaPipelineBase
* @param accum Input/output accumulator WaveTile C.
* @return The output WaveTile D after accumulation and post-transform.
*/
template <typename ATensor, typename BTensor, typename CTensor>
template <typename... Params, typename ATensor, typename BTensor, typename CTensor>
CK_TILE_DEVICE static decltype(auto) exec(ATensor& a, BTensor& b, CTensor& accum)
{
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
{
if constexpr(Flags & MmaPipelineOptionFlag::ABSwap)
if constexpr(Derived::CTranspose)
{
decltype(auto) a_transformed = Derived::ATransform::exec(b);
decltype(auto) b_transformed = Derived::BTransform::exec(a);
decltype(auto) c_transformed = Derived::CTransform::exec(accum);
Derived::execImpl(a_transformed, b_transformed, c_transformed);
Derived::template execImpl<Params...>(a_transformed, b_transformed, c_transformed);
return Derived::DTransform::exec(c_transformed);
}
else
@@ -143,7 +63,7 @@ struct MmaPipelineBase
decltype(auto) a_transformed = Derived::ATransform::exec(a);
decltype(auto) b_transformed = Derived::BTransform::exec(b);
decltype(auto) c_transformed = Derived::CTransform::exec(accum);
Derived::execImpl(a_transformed, b_transformed, c_transformed);
Derived::template execImpl<Params...>(a_transformed, b_transformed, c_transformed);
return Derived::DTransform::exec(c_transformed);
}
}
@@ -153,7 +73,7 @@ struct MmaPipelineBase
// Code should not reach here, but HOST/DEVICE compile passes are
// weirdly intertwined and instead of having constexpr in the calling
// site (tests) we do this. See also changes by this commit.
return Derived::MmaOp::exec({}, {}, {});
return Derived::MmaOp::template exec<Params...>({}, {}, {});
}
}
@@ -162,11 +82,10 @@ struct MmaPipelineBase
template <typename... Params, typename CTensor, typename ATensor, typename BTensor>
CK_TILE_DEVICE void operator()(CTensor& c, ATensor& a, const BTensor& b) const
{
exec(a, b, c);
exec<Params...>(a, b, c);
}
template <index_t opselA,
index_t opselB,
template <typename... Params,
typename ATensor,
typename BTensor,
typename CTensor,
@@ -180,7 +99,7 @@ struct MmaPipelineBase
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
{
if constexpr(Flags & MmaPipelineOptionFlag::ABSwap)
if constexpr(Derived::CTranspose)
{
// TODO: Figure out which combination of a/b, scale_A/B, and opselA/B needs to be
// AB-swapped in order to get correct results. Note that WarpGemmParamsParser
@@ -188,7 +107,7 @@ struct MmaPipelineBase
decltype(auto) a_transformed = Derived::ATransform::exec(b);
decltype(auto) b_transformed = Derived::BTransform::exec(a);
decltype(auto) c_transformed = Derived::CTransform::exec(accum);
Derived::template execImpl<opselA, opselB>(
Derived::template execImpl<Params...>(
a_transformed, b_transformed, c_transformed, scale_A, scale_B);
return Derived::DTransform::exec(c_transformed);
}
@@ -197,7 +116,7 @@ struct MmaPipelineBase
decltype(auto) a_transformed = Derived::ATransform::exec(a);
decltype(auto) b_transformed = Derived::BTransform::exec(b);
decltype(auto) c_transformed = Derived::CTransform::exec(accum);
Derived::template execImpl<opselA, opselB>(
Derived::template execImpl<Params...>(
a_transformed, b_transformed, c_transformed, scale_A, scale_B);
return Derived::DTransform::exec(c_transformed);
}
@@ -219,8 +138,7 @@ struct MmaPipelineBase
const int32_t& a_scale,
const int32_t& b_scale) const
{
using P = WarpGemmParamsParser<Params...>;
exec<P::op_sel_a, P::op_sel_b>(a, b, c, a_scale, b_scale);
exec<Params...>(a, b, c, a_scale, b_scale);
}
};
@@ -232,8 +150,8 @@ struct MmaPipelineBase
* @concept MmaPipelineI
* @brief Expresses the meta-data interface required for a CRTP MmaPipeline.
*/
template <typename Derived, MmaPipelineOptionFlags::Type Flags>
concept MmaPipelineInterface = std::derived_from<Derived, MmaPipelineBase<Flags, Derived>>;
template <typename Derived>
concept MmaPipelineInterface = std::derived_from<Derived, MmaPipelineBase<Derived>>;
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER

View File

@@ -49,7 +49,6 @@ struct MmaDefaultSelector
WaveTileM,
WaveTileN,
WaveTileK,
void,
amdgcn_target<>,
MmaOpFamily::UNDEFINED>;
};
@@ -88,7 +87,6 @@ template <typename ADataType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileKTest,
typename CtrlFlags,
typename CompilerTarget, // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
MmaOpFamily OpFamily>
struct MmaKSearchSelector
@@ -102,7 +100,6 @@ struct MmaKSearchSelector
WaveTileM,
WaveTileN,
WaveTileKTest,
CtrlFlags,
CompilerTarget,
OpFamily>;
@@ -118,7 +115,6 @@ struct MmaKSearchSelector
WaveTileM,
WaveTileN,
WaveTileKTest / 2u,
CtrlFlags,
CompilerTarget,
OpFamily>::SelectedOp>;
};
@@ -128,7 +124,6 @@ template <typename ADataType,
typename CDataType,
uint32_t WaveTileM,
uint32_t WaveTileN,
typename CtrlFlags,
typename CompilerTarget, // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
MmaOpFamily OpFamily>
struct MmaKSearchSelector<ADataType,
@@ -137,20 +132,12 @@ struct MmaKSearchSelector<ADataType,
WaveTileM,
WaveTileN,
0u,
CtrlFlags,
CompilerTarget,
OpFamily>
{
// Recursion endpoint: unsupported default implementation.
using SelectedOp = amdgcn_mma<ADataType,
BDataType,
CDataType,
1u,
1u,
1u,
CtrlFlags,
CompilerTarget,
OpFamily>;
using SelectedOp =
amdgcn_mma<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget, OpFamily>;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -6,11 +6,8 @@
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/config.hpp"
#include "mfma/mfma_traits.hpp"
#include "scale/scale_traits.hpp"
#include "sparse/sparse_traits.hpp"
#include "wmma/wmma_traits.hpp"
#include <cstdint>
#include <stdio.h>
#include <type_traits>
@@ -61,7 +58,6 @@ struct MmaOpTraits;
* @tparam FragM_ Size of the M dimension
* @tparam FragN_ Size of the N dimension
* @tparam FragK_ Size of the K dimension
* @tparam CtrlFlags_ Control flags for the MMA operation
* @tparam CompilerTarget_ The compiler target
*/
template <typename ADataType_,
@@ -70,7 +66,6 @@ template <typename ADataType_,
uint32_t FragM_,
uint32_t FragN_,
uint32_t FragK_,
typename CtrlFlags_,
typename CompilerTarget_,
MmaOpFamily OpFamily_>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget_>
@@ -80,7 +75,6 @@ struct MmaOpTraits<amdgcn_mma<ADataType_,
FragM_,
FragN_,
FragK_,
CtrlFlags_,
CompilerTarget_,
OpFamily_>>
{
@@ -90,12 +84,10 @@ struct MmaOpTraits<amdgcn_mma<ADataType_,
FragM_,
FragN_,
FragK_,
CtrlFlags_,
CompilerTarget_,
OpFamily_>;
// Capture incoming template parameters not already in amdgcn
using CtrlFlags = CtrlFlags_;
using CompilerTarget = CompilerTarget_;
// TODO c++20static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_;
@@ -115,7 +107,6 @@ template <typename ADataType_,
uint32_t FragM_,
uint32_t FragN_,
uint32_t FragK_,
typename CtrlFlags_,
typename CompilerTarget_,
MmaOpFamily OpFamily_>
CK_TILE_HOST_DEVICE void print(MmaOpTraits<amdgcn_mma<ADataType_,
@@ -124,7 +115,6 @@ CK_TILE_HOST_DEVICE void print(MmaOpTraits<amdgcn_mma<ADataType_,
FragM_,
FragN_,
FragK_,
CtrlFlags_,
CompilerTarget_,
OpFamily_>> const& traitsObj)
{
@@ -134,7 +124,6 @@ CK_TILE_HOST_DEVICE void print(MmaOpTraits<amdgcn_mma<ADataType_,
FragM_,
FragN_,
FragK_,
CtrlFlags_,
CompilerTarget_,
OpFamily_>{});
printf(

View File

@@ -28,15 +28,6 @@ enum struct MmaAccumPolicy
COL_MAJOR
};
namespace dense::wavewise::detail {
// TODO: c++20: return MmaPipelineOptionFlags directly
template <bool SwapAB>
constexpr inline int getPipelineFlags()
{
return static_cast<int>(SwapAB ? MmaPipelineOptionFlag::ABSwap : MmaPipelineOptionFlag::NONE);
}
} // namespace dense::wavewise::detail
/**
* @class Mma
* @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation
@@ -50,7 +41,7 @@ constexpr inline int getPipelineFlags()
* @tparam WaveTileN Mma WaveTile N dimension
* @tparam WaveTileK Mma WaveTile K dimension
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
* @tparam CTranspose Swaps A and B input vectors and interprets C with transposed layout.
* @tparam CTranspose_ Swaps A and B input vectors and interprets C with transposed layout.
* @tparam SwizzleFactor SwizzleFactor for Tile Distribution Encoding calculation.
* @tparam AttrNumAccessAV Extra unmerge factor for vector dimension for A vec, see amdgcn_mma.hpp.
* @tparam AttrNumAccessBV Extra unmerge factor for vector dimension for B vec, see amdgcn_mma.hpp.
@@ -72,7 +63,7 @@ template <typename ADataType_,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
bool CTranspose = false,
bool CTranspose_ = false,
index_t SwizzleFactor = 1,
index_t AttrNumAccessAV = 1,
index_t AttrNumAccessBV = AttrNumAccessAV,
@@ -92,11 +83,12 @@ template <typename ADataType_,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct WaveWiseMmaPipeline : public MmaPipelineBase<dense::wavewise::detail::getPipelineFlags<CTranspose>(), WaveWiseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>
struct WaveWiseMmaPipeline : public MmaPipelineBase<WaveWiseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose_, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<dense::wavewise::detail::getPipelineFlags<CTranspose>(), WaveWiseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>;
using Base = MmaPipelineBase<WaveWiseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose_, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
using MmaOp = MmaOp_;
using MmaOp = MmaOp_;
static constexpr bool CTranspose = CTranspose_;
using ADataType = typename MmaOp::ADataType;
using BDataType = typename MmaOp::BDataType;
@@ -185,7 +177,7 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase<dense::wavewise::detail::get
static_assert(WaveTileK % MmaOp::kK == 0u, "WaveTileK must be a multiple of MmaOp::kK");
// TODO: Why does this even need to be a template? The types should be known.
template <typename ATensor, typename BTensor, typename CTensor>
template <typename... Params, typename ATensor, typename BTensor, typename CTensor>
CK_TILE_DEVICE static void execImpl(ATensor& a, BTensor& b, CTensor& c)
{
static_assert(
@@ -205,9 +197,10 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase<dense::wavewise::detail::get
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_buf.at(bm * FragsN + bn) = MmaOp::exec(a_buf.at(bm * FragsK + bk),
b_buf.at(bn * FragsK + bk),
c_buf.at(bm * FragsN + bn));
c_buf.at(bm * FragsN + bn) =
MmaOp::template exec<Params...>(a_buf.at(bm * FragsK + bk),
b_buf.at(bn * FragsK + bk),
c_buf.at(bm * FragsN + bn));
}
}
}
@@ -220,9 +213,10 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase<dense::wavewise::detail::get
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_buf.at(bm * FragsN + bn) = MmaOp::exec(a_buf.at(bm * FragsK + bk),
b_buf.at(bn * FragsK + bk),
c_buf.at(bm * FragsN + bn));
c_buf.at(bm * FragsN + bn) =
MmaOp::template exec<Params...>(a_buf.at(bm * FragsK + bk),
b_buf.at(bn * FragsK + bk),
c_buf.at(bm * FragsN + bn));
}
}
}

View File

@@ -13,6 +13,7 @@
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp"
namespace ck_tile::core::arch::mma {
@@ -23,14 +24,13 @@ namespace ck_tile::core::arch::mma {
* This specialization implements the Scale MFMA instruction for fp8_t A and B
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 2, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
// clang-format on
@@ -38,19 +38,20 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
static constexpr const char* instruction_name =
"__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4";
template <index_t opselA, index_t opselB>
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
opselA,
P::op_sel_a,
scale_A,
opselB,
P::op_sel_b,
scale_B)};
}
};
@@ -62,14 +63,13 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
* This specialization implements the Scale MFMA instruction for bf8_t A and B
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 2, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
// clang-format on
@@ -77,19 +77,20 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
static constexpr const char* instruction_name =
"__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4";
template <index_t opselA, index_t opselB>
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
opselA,
P::op_sel_a,
scale_A,
opselB,
P::op_sel_b,
scale_B)};
}
};
@@ -101,14 +102,13 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
* This specialization implements the Scale MFMA instruction for pk_fp4_t A and B
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
// clang-format on
@@ -116,10 +116,11 @@ struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, CtrlFlags, Compile
static constexpr const char* instruction_name =
"__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4";
template <index_t opselA, index_t opselB>
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
using P = WarpGemmParamsParser<Params...>;
int32x4_t arg_a = bit_cast<int32x4_t>(aVec);
int32x4_t arg_b = bit_cast<int32x4_t>(bVec);
@@ -129,9 +130,9 @@ struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, CtrlFlags, Compile
cVec,
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
opselA,
P::op_sel_a,
scale_A,
opselB,
P::op_sel_b,
scale_B)};
}
};
@@ -143,33 +144,33 @@ struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, CtrlFlags, Compile
* This specialization implements the Scale MFMA instruction for pk_fp6x16_t A and B
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
struct amdgcn_mma<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4";
template <index_t opselA, index_t opselB>
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0},
int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0},
cVec,
scale::detail::ScaleDataTypeToFlag_v<pk_fp6x16_t>,
scale::detail::ScaleDataTypeToFlag_v<pk_fp6x16_t>,
opselA,
P::op_sel_a,
scale_A,
opselB,
P::op_sel_b,
scale_B)};
}
};
@@ -182,33 +183,33 @@ struct amdgcn_mma<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 16u, 16u, 128u, CtrlFlags, C
* This specialization implements the Scale MFMA instruction for pk_bf6x16_t A and B
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
struct amdgcn_mma<pk_bf6x16_t, pk_bf6x16_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<pk_bf6x16_t, pk_bf6x16_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<pk_bf6x16_t, pk_bf6x16_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4";
template <index_t opselA, index_t opselB>
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0},
int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0},
cVec,
scale::detail::ScaleDataTypeToFlag_v<pk_bf6x16_t>,
scale::detail::ScaleDataTypeToFlag_v<pk_bf6x16_t>,
opselA,
P::op_sel_a,
scale_A,
opselB,
P::op_sel_b,
scale_B)};
}
};
@@ -221,14 +222,13 @@ struct amdgcn_mma<pk_bf6x16_t, pk_bf6x16_t, fp32_t, 16u, 16u, 128u, CtrlFlags, C
* This specialization implements the Scale MFMA instruction for fp8_t A and B
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 2, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
// clang-format on
@@ -236,19 +236,20 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget
static constexpr const char* instruction_name =
"__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4";
template <index_t opselA, index_t opselB>
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
opselA,
P::op_sel_a,
scale_A,
opselB,
P::op_sel_b,
scale_B)};
}
};
@@ -260,14 +261,13 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget
* This specialization implements the Scale MFMA instruction for bf8_t A and B
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 2, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
// clang-format on
@@ -275,19 +275,20 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget
static constexpr const char* instruction_name =
"__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4";
template <index_t opselA, index_t opselB>
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
opselA,
P::op_sel_a,
scale_A,
opselB,
P::op_sel_b,
scale_B)};
}
};
@@ -299,14 +300,13 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget
* This specialization implements the Scale MFMA instruction for pk_fp4_t A and B
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
// clang-format on
@@ -314,10 +314,11 @@ struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, CtrlFlags, Compiler
static constexpr const char* instruction_name =
"__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4";
template <index_t opselA, index_t opselB>
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
using P = WarpGemmParamsParser<Params...>;
int32x4_t arg_a = bit_cast<int32x4_t>(aVec);
int32x4_t arg_b = bit_cast<int32x4_t>(bVec);
@@ -327,9 +328,9 @@ struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, CtrlFlags, Compiler
cVec,
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
opselA,
P::op_sel_a,
scale_A,
opselB,
P::op_sel_b,
scale_B)};
}
};
@@ -341,33 +342,33 @@ struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, CtrlFlags, Compiler
* This specialization implements the Scale MFMA instruction for pk_fp6x16_t A and B
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
struct amdgcn_mma<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4";
template <index_t opselA, index_t opselB>
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0},
int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0},
cVec,
scale::detail::ScaleDataTypeToFlag_v<pk_fp6x16_t>,
scale::detail::ScaleDataTypeToFlag_v<pk_fp6x16_t>,
opselA,
P::op_sel_a,
scale_A,
opselB,
P::op_sel_b,
scale_B)};
}
};
@@ -380,33 +381,33 @@ struct amdgcn_mma<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 32u, 32u, 64u, CtrlFlags, Co
* This specialization implements the Scale MFMA instruction for pk_bf6x16_t A and B
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
struct amdgcn_mma<pk_bf6x16_t, pk_bf6x16_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<pk_bf6x16_t, pk_bf6x16_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<pk_bf6x16_t, pk_bf6x16_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4";
template <index_t opselA, index_t opselB>
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0},
int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0},
cVec,
scale::detail::ScaleDataTypeToFlag_v<pk_bf6x16_t>,
scale::detail::ScaleDataTypeToFlag_v<pk_bf6x16_t>,
opselA,
P::op_sel_a,
scale_A,
opselB,
P::op_sel_b,
scale_B)};
}
};

View File

@@ -55,7 +55,6 @@ struct MmaDefaultSelector<ADataType,
WaveTileM,
WaveTileN,
WaveTileK,
DefaultScaleMfmaCtrlFlags,
CompilerTarget,
MmaOpFamily::SCALE>;
};

View File

@@ -32,7 +32,7 @@ namespace ck_tile::core::arch::mma {
* @tparam WaveTileN Mma WaveTile N dimension
* @tparam WaveTileK Mma WaveTile K dimension
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
* @tparam CTranspose Swaps A and B input vectors and interprets C with transposed layout.
* @tparam CTranspose_ Swaps A and B input vectors and interprets C with transposed layout.
* @tparam SwizzleFactor Swizzlefactor for Tile Distribution Encoding calculation.
* @tparam AttrNumAccessAV Extra unmerge factor for vector dimension for A vec, see amdgcn_mma.hpp.
* @tparam AttrNumAccessBV Extra unmerge factor for vector dimension for B vec, see amdgcn_mma.hpp.
@@ -47,7 +47,7 @@ template <typename ADataType_,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
bool CTranspose = false,
bool CTranspose_ = false,
index_t SwizzleFactor = 1,
index_t AttrNumAccessAV = 1,
index_t AttrNumAccessBV = AttrNumAccessAV,
@@ -67,12 +67,13 @@ template <typename ADataType_,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>
struct ScaleMmaPipeline : public MmaPipelineBase<ScaleMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose_, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>;
using Base = MmaPipelineBase<ScaleMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose_, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
using MmaOp = MmaOp_; // Expose the selected MmaOp
using MmaOp = MmaOp_; // Expose the selected MmaOp
static constexpr bool CTranspose = CTranspose_;
using ADataType = typename MmaOp::ADataType;
using BDataType = typename MmaOp::BDataType;
@@ -170,8 +171,7 @@ struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOpt
static_assert(WaveTileK % MmaOp::kK == 0u, "WaveTileK must be a multiple of MmaOp::kK");
// TODO: Why does this even need to be a template? The types should be known.
template <index_t opselA,
index_t opselB,
template <typename... Params,
typename ATensor,
typename BTensor,
typename CTensor,
@@ -198,11 +198,11 @@ struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOpt
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_buf.at(bm * FragsN + bn) =
MmaOp::template exec<opselA, opselB>(a_buf.at(bm * FragsK + bk),
b_buf.at(bn * FragsK + bk),
c_buf.at(bm * FragsN + bn),
scale_A,
scale_B);
MmaOp::template exec<Params...>(a_buf.at(bm * FragsK + bk),
b_buf.at(bn * FragsK + bk),
c_buf.at(bm * FragsN + bn),
scale_A,
scale_B);
}
}
}
@@ -216,11 +216,11 @@ struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOpt
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_buf.at(bm * FragsN + bn) =
MmaOp::template exec<opselA, opselB>(a_buf.at(bm * FragsK + bk),
b_buf.at(bn * FragsK + bk),
c_buf.at(bm * FragsN + bn),
scale_A,
scale_B);
MmaOp::template exec<Params...>(a_buf.at(bm * FragsK + bk),
b_buf.at(bn * FragsK + bk),
c_buf.at(bm * FragsN + bn),
scale_A,
scale_B);
}
}
}

View File

@@ -3,85 +3,32 @@
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/pk_f6.hpp"
#include <cstdint>
#include <stdio.h>
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
namespace ck_tile::core::arch::mma {
namespace scale::detail {
// Utility for converting the datatype of the A or B input matrix in a scale intrinsics to the
// appropriate datatype flag. Note that this is not the same as the flag indicating the scale
// datatype, see ScaleDataTypeToEnum.
template <typename T>
struct ScaleDataTypeToFlag;
inline constexpr int32_t ScaleDataTypeToFlag_v = [] {
// sizeof(T) trick to only trigger the static assert for unsupported datatypes.
static_assert(sizeof(T) == 0, "Unsupported scale data type");
return -1;
}();
template <>
struct ScaleDataTypeToFlag<fp8_t> // e4m3 (4 exponent bits 3 mantissa bits)
{
static constexpr int32_t value = 0;
};
inline constexpr int32_t ScaleDataTypeToFlag_v<fp8_t> = 0; // e4m3
template <>
struct ScaleDataTypeToFlag<bf8_t> // e5m2
{
static constexpr int32_t value = 1;
};
inline constexpr int32_t ScaleDataTypeToFlag_v<bf8_t> = 1; // e5m2
template <>
struct ScaleDataTypeToFlag<pk_fp6x16_t> // e2m3
{
static constexpr int32_t value = 2;
};
inline constexpr int32_t ScaleDataTypeToFlag_v<pk_fp6x16_t> = 2; // e2m3
template <>
struct ScaleDataTypeToFlag<pk_bf6x16_t> // e3m2
{
static constexpr int32_t value = 3;
};
inline constexpr int32_t ScaleDataTypeToFlag_v<pk_bf6x16_t> = 3; // e3m2
template <>
struct ScaleDataTypeToFlag<pk_fp4_t> // e2m1
{
static constexpr int32_t value = 4;
};
template <typename T>
inline constexpr int32_t ScaleDataTypeToFlag_v = ScaleDataTypeToFlag<T>::value;
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
/**
* @concept ScaleMfmaDataTypeToFlag
* @brief Expresses the interface of required members for each DataTypeToFlag type on Gfx9
*/
template <typename DataTypeToFlag>
concept ScaleMfmaDataTypeToFlag = requires(DataTypeToFlag dataTypeToFlag) {
// Flag members for scale MFMA instructions
{ DataTypeToFlag::value } -> std::convertible_to<int32_t>;
};
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
inline constexpr int32_t ScaleDataTypeToFlag_v<pk_fp4_t> = 4; // e2m1
} // namespace scale::detail
// No real flags for now, scale and opsel are handled in higher level and passed down directly.
// OPSEL is now passed as a template arg to exec(), see mma_pipeline.hpp
// We will soon get rid of these flags entirely in favor of variadic template packs passed down to
// the intrinsics directly, see WarpGemmParamsParser<>.
struct DefaultScaleMfmaCtrlFlags
{
};
CK_TILE_HOST_DEVICE void print_flags([[maybe_unused]] DefaultScaleMfmaCtrlFlags const& ctrlFlags)
{
printf("CtrlFlags: (empty)\n");
}
} // namespace ck_tile::core::arch::mma

View File

@@ -7,7 +7,6 @@
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
namespace ck_tile::core::arch::mma {
@@ -55,7 +54,6 @@ struct MmaDefaultSelector<ADataType,
WaveTileM,
WaveTileN,
WaveTileK,
DefaultSparseMfmaCtrlFlags,
CompilerTarget,
MmaOpFamily::SPARSE>::SelectedOp;
};

View File

@@ -7,11 +7,11 @@
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp"
#include <type_traits>
@@ -24,55 +24,55 @@ namespace ck_tile::core::arch::mma {
* This specialization implements the SMFMA instruction for fp16_t A and B
* matrices with structured sparsity, fp32_t accumulator, with 16x16x32 fragment sizes.
*
* @tparam CtrlFlags Control flags for the Sparse MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsSparseMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950)>>
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 64u, 8, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x32_f16";
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_16x16x32_f16(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
}
using P = WarpGemmParamsParser<Params...>;
return __builtin_amdgcn_smfmac_f32_16x16x32_f16(
aVec,
bVec,
cVec,
idx,
P::cbsz, // Ignore abid and use first portion Y/N
P::abid); // Portion of idx VGPR containing idx info
};
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX942 and GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 32u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 32u, 16u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 32u, 32u, 16u, 64u, 8, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x16_f16";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_32x32x16_f16(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_smfmac_f32_32x32x16_f16(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -80,27 +80,24 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 32u, 16u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX942 and GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, 64u, 8, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x32_bf16";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_16x16x32_bf16(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_smfmac_f32_16x16x32_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -108,27 +105,24 @@ struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX942 and GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 32u, 32u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 32u, 32u, 16u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 32u, 32u, 16u, 64u, 8, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x16_bf16";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_32x32x16_bf16(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_smfmac_f32_32x32x16_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -136,27 +130,24 @@ struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 32u, 32u, 16u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX942 and
* GFX950 architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 64u, 64u, 16, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_16x16x64_i8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_i32_16x16x64_i8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_smfmac_i32_16x16x64_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -164,27 +155,24 @@ struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 64u, CtrlFlags, CompilerTar
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX942 and
* GFX950 architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<int8_t, int8_t, int32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
struct amdgcn_mma<int8_t, int8_t, int32_t, 32u, 32u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<int8_t, int8_t, int32_t, 32u, 32u, 32u, 64u, 16, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_32x32x32_i8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_i32_32x32x32_i8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_smfmac_i32_32x32x32_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -192,27 +180,25 @@ struct amdgcn_mma<int8_t, int8_t, int32_t, 32u, 32u, 32u, CtrlFlags, CompilerTar
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX942 and GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 64u, 64u, 16, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_bf8_bf8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_16x16x64_bf8_bf8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_16x16x64_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -220,27 +206,25 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX942 and GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 16u, 16u, 64u, 64u, 16, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_bf8_fp8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_16x16x64_bf8_fp8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_16x16x64_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -248,27 +232,25 @@ struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX942 and GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 16u, 16u, 64u, 64u, 16, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_fp8_bf8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_16x16x64_fp8_bf8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_16x16x64_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -276,27 +258,25 @@ struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX942 and GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 64u, 64u, 16, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_fp8_fp8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_16x16x64_fp8_fp8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_16x16x64_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -304,27 +284,25 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX942 and GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 32u, 32u, 32u, 64u, 16, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_bf8_bf8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_32x32x32_bf8_bf8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_32x32x32_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -332,27 +310,25 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX942 and GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 32u, 32u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 32u, 32u, 32u, 64u, 16, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_bf8_fp8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_32x32x32_bf8_fp8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_32x32x32_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -360,27 +336,25 @@ struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX942 and GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 32u, 32u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 32u, 32u, 32u, 64u, 16, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_fp8_bf8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_32x32x32_fp8_bf8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_32x32x32_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -388,27 +362,25 @@ struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX942 and GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 32u, 32u, 32u, 64u, 16, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_fp8_fp8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_32x32x32_fp8_fp8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_32x32x32_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -416,27 +388,24 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 64u, 64u, 16, 1, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_f16";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_16x16x64_f16(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_smfmac_f32_16x16x64_f16(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -444,27 +413,24 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 32u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 32u, 32u, 32u, 64u, 16, 1, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_f16";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_32x32x32_f16(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_smfmac_f32_32x32x32_f16(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -472,27 +438,24 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 64u, 64u, 16, 1, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_bf16";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_16x16x64_bf16(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_smfmac_f32_16x16x64_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -500,27 +463,24 @@ struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 32u, 32u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 32u, 32u, 32u, 64u, 16, 1, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_bf16";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_32x32x32_bf16(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_smfmac_f32_32x32x32_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -528,27 +488,24 @@ struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_16x16x128_i8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_i32_16x16x128_i8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_smfmac_i32_16x16x128_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -556,27 +513,24 @@ struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 128u, CtrlFlags, CompilerTa
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<int8_t, int8_t, int32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<int8_t, int8_t, int32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<int8_t, int8_t, int32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_32x32x64_i8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_i32_32x32x64_i8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_smfmac_i32_32x32x64_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -584,27 +538,25 @@ struct amdgcn_mma<int8_t, int8_t, int32_t, 32u, 32u, 64u, CtrlFlags, CompilerTar
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_bf8_bf8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_16x16x128_bf8_bf8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_16x16x128_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -612,27 +564,25 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_bf8_fp8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_16x16x128_bf8_fp8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_16x16x128_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -640,27 +590,25 @@ struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_fp8_bf8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_16x16x128_fp8_bf8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_16x16x128_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -668,27 +616,25 @@ struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_fp8_fp8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_16x16x128_fp8_fp8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_16x16x128_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -696,27 +642,25 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_bf8_bf8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_32x32x64_bf8_bf8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_32x32x64_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -724,27 +668,25 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_bf8_fp8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_32x32x64_bf8_fp8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_32x32x64_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -752,27 +694,25 @@ struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_fp8_bf8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_32x32x64_fp8_bf8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_32x32x64_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
@@ -780,27 +720,25 @@ struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX950
* architecture.
* @tparam CtrlFlags Control flags for the MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_fp8_fp8";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using namespace sparse::detail;
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
return {__builtin_amdgcn_smfmac_f32_32x32x64_fp8_fp8(
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
using P = WarpGemmParamsParser<Params...>;
return {
__builtin_amdgcn_smfmac_f32_32x32x64_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -11,5 +11,4 @@ namespace ck_tile::core::arch::mma {
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"

View File

@@ -12,16 +12,6 @@
namespace ck_tile::core::arch::mma {
namespace sparse::detail {
// TODO: c++20: return MmaPipelineOptionFlags directly
template <bool SwapAB>
constexpr inline int getPipelineFlags()
{
return static_cast<int>(MmaPipelineOptionFlag::COMPRESS_A) |
static_cast<int>(SwapAB ? MmaPipelineOptionFlag::ABSwap : MmaPipelineOptionFlag::NONE);
}
} // namespace sparse::detail
/**
* @class SparseMmaPipeline
* @brief Driver for the wave-tile sparse Mma operation. Given a backend MmaOp implementation
@@ -38,7 +28,7 @@ constexpr inline int getPipelineFlags()
* @tparam WaveTileN Mma WaveTile N dimension
* @tparam WaveTileK Mma WaveTile K dimension
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
* @tparam CTranspose Swaps A and B input vectors and interprets C with transposed layout.
* @tparam CTranspose_ Swaps A and B input vectors and interprets C with transposed layout.
* @tparam SwizzleFactor SwizzleFactor for Tile Distribution Encoding calculation.
* @tparam AttrNumAccessAV Extra unmerge factor for vector dimension for A vec, see amdgcn_mma.hpp.
* @tparam AttrNumAccessBV Extra unmerge factor for vector dimension for B vec, see amdgcn_mma.hpp.
@@ -53,7 +43,7 @@ template <typename ADataType_,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
bool CTranspose = false,
bool CTranspose_ = false,
index_t SwizzleFactor = 1,
index_t AttrNumAccessAV = 1,
index_t AttrNumAccessBV = AttrNumAccessAV,
@@ -73,11 +63,12 @@ template <typename ADataType_,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFlags<CTranspose>(), SparseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>
struct SparseMmaPipeline : public MmaPipelineBase<SparseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose_, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<sparse::detail::getPipelineFlags<CTranspose>(), SparseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>;
using Base = MmaPipelineBase<SparseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose_, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
using MmaOp = MmaOp_;
using MmaOp = MmaOp_;
static constexpr bool CTranspose = CTranspose_;
using ADataType = typename MmaOp::ADataType;
using BDataType = typename MmaOp::BDataType;
@@ -86,8 +77,7 @@ struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFla
static_assert(!MmaOpTraits<MmaOp>::IsSupported || std::is_same_v<ADataType, ADataType_>);
static_assert(!MmaOpTraits<MmaOp>::IsSupported || std::is_same_v<BDataType, BDataType_>);
static_assert(!MmaOpTraits<MmaOp>::IsSupported || std::is_same_v<CDataType, CDataType_>);
static_assert(!(Base::Flags & MmaPipelineOptionFlag::ABSwap),
"Cannot transpose C in sparse intrinsics.");
static_assert(!CTranspose, "Cannot transpose C in sparse intrinsics.");
// WaveTile dimensions (Used to be fragment dims but higher level expects these to include k
// iteration!)
@@ -180,7 +170,7 @@ struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFla
// ATransformResult is a big ext_vector plus idx, B and C are static_distributed tensors. Fix
// later TODO.
template <typename ATransformResult, typename BTensor, typename CTensor>
template <typename... Params, typename ATransformResult, typename BTensor, typename CTensor>
CK_TILE_DEVICE static void execImpl(ATransformResult& a, BTensor& b, CTensor& c)
{
static_assert(
@@ -206,7 +196,7 @@ struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFla
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_buf.at(bm * FragsN + bn) = MmaOp::exec(
c_buf.at(bm * FragsN + bn) = MmaOp::template exec<Params...>(
a_frags[bm][bk],
b_buf.at(bn * FragsK + bk),
c_buf.at(bm * FragsN + bn),
@@ -224,7 +214,7 @@ struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFla
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_buf.at(bm * FragsN + bn) = MmaOp::exec(
c_buf.at(bm * FragsN + bn) = MmaOp::template exec<Params...>(
a_frags[bm][bk],
b_buf.at(bn * FragsK + bk),
c_buf.at(bm * FragsN + bn),

View File

@@ -1,106 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include <stdio.h>
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>
#endif
namespace ck_tile::core::arch::mma {
/**
* @enum SparseCompressionIndex
* @brief Indicates which set of sparse-indices within a VGPR starting at srcC
* containing 8-bits (for 16-bit source data) or 16-bits (for 8-bit source data)
* of index information for a lane. \see DefaultSparseMfmaCtrlFlags
*/
enum struct SparseCompressionIndex : int
{
FIRST = 0, // Uses bits [7:0] or [15..0], for 16 and 8 bit data respectively
SECOND = 1, // Uses bits [15:8] or [31:16], for 16 and 8 bit data respectively
THIRD = 2, // Uses bits [23:16]
FOURTH = 3, // Uses bits [31:24]
};
// to_string methods for enum classes
CK_TILE_HOST_DEVICE constexpr const char* to_string(SparseCompressionIndex compressionIndex)
{
switch(compressionIndex)
{
case SparseCompressionIndex::FIRST: return "FIRST";
case SparseCompressionIndex::SECOND: return "SECOND";
case SparseCompressionIndex::THIRD: return "THIRD";
case SparseCompressionIndex::FOURTH: return "FOURTH";
}
__builtin_unreachable();
}
namespace sparse::detail {
/**
* @struct BuiltinParams
* @brief Translates the SparseCompressionIndex to the correct CBSZ and ABID pairs for sparse
* builtins. The actual behavior of the builtin depends on the input data type: 16-bit source data:
* If CBSZ=0, ABID selects one of four 8-bit sets of sparse-indices within a VGPR starting at srcC
* containing 8-bits of index information for a lane. If CBSZ!=0 the very first is selected
* (VGPR[srcC][7..0]).
*
* 8-bit source data:
* If CBSZ=0, ABID selects one of two 16-bit sets of sparse-indices within a VGPR starting at srcC
* containing 16-bits of index information for a lane. If CBSZ!=0; the very first is selected
* (VGPR[srcC][15..0]).
*/
struct BuiltinParams
{
int UseFirstIndex; // CBSZ
int ByteIndexToOverride; // ABID
};
template <SparseCompressionIndex Idx>
static constexpr BuiltinParams getBuiltinParams()
{
// TODO c++20: designated initializers
if constexpr(Idx == SparseCompressionIndex::FIRST)
{
return BuiltinParams{1, 0};
}
else
{
return BuiltinParams{0, static_cast<int>(Idx)};
}
}
} // namespace sparse::detail
/**
* @struct DefaultSparseMfmaCtrlFlags
* @brief Default MFMA sparse flags, select (VGPR[srcC][7..0]) if srcC is
* 16-bit or (VGPR[srcC][15..0]) if srcC is 8-bit.
*/
struct DefaultSparseMfmaCtrlFlags
{
static constexpr SparseCompressionIndex CompressionIndex = SparseCompressionIndex::FIRST;
};
CK_TILE_HOST_DEVICE void print_flags(DefaultSparseMfmaCtrlFlags const& ctrlFlags)
{
printf("CtrlFlags CompressionIndex : %s\n", to_string(ctrlFlags.CompressionIndex));
}
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
/**
* @concept SparseMfmaCtrlFlags
* @brief Expresses the interface of required members for each CtrlFlags type
*/
template <typename CtrlFlags>
concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) {
// Flag members for sparse MFMA instructions
{ CtrlFlags::CompressionIndex } -> std::convertible_to<SparseCompressionIndex>;
};
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
} // namespace ck_tile::core::arch::mma

View File

@@ -53,7 +53,6 @@ struct MmaDefaultSelector<ADataType,
WaveTileM,
WaveTileN,
WaveTileK,
DefaultWmmaCtrlFlags,
CompilerTarget,
MmaOpFamily::SPARSE>::SelectedOp;
};

View File

@@ -11,6 +11,7 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp"
namespace ck_tile::core::arch::mma {
@@ -18,20 +19,20 @@ namespace ck_tile::core::arch::mma {
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
@@ -43,20 +44,20 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
@@ -68,20 +69,20 @@ struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp16_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp16_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
@@ -93,21 +94,21 @@ struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, bf16_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<bf16_t, bf16_t, bf16_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
@@ -119,30 +120,31 @@ struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32(true, // A signedness
aVec,
true, // B signedness
bVec,
cVec,
idx,
CtrlFlags::Clamp)};
P::clamp)};
}
};
@@ -150,21 +152,21 @@ struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 32u, CtrlFlags, CompilerTar
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
@@ -176,21 +178,21 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
@@ -202,21 +204,21 @@ struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
@@ -228,21 +230,21 @@ struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
@@ -250,51 +252,55 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget
}
};
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32";
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32(true, // A signedness
bit_cast<int32_t>(aVec),
true, // B signedness
bit_cast<int32x2_t>(bVec),
cVec,
idx,
CtrlFlags::Clamp)};
P::clamp)};
}
};
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 64u, 32u, 32, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32";
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32(true, // A signedness
bit_cast<int32x2_t>(aVec),
true, // B signedness
bit_cast<int32x4_t>(bVec),
cVec,
idx,
CtrlFlags::Clamp)};
P::clamp)};
}
};

View File

@@ -16,6 +16,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp"
namespace ck_tile::core::arch::mma {
// TODO: Specifically for gfx11 wmma, we need to deal with quirks such as:
@@ -46,20 +47,20 @@ namespace ck_tile::core::arch::mma {
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX11
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
@@ -71,20 +72,20 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX11
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
@@ -96,29 +97,30 @@ struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX11
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, // A signedness
bit_cast<int32x4_t>(aVec),
true, // B signedness
bit_cast<int32x4_t>(bVec),
cVec,
CtrlFlags::Clamp)};
P::clamp)};
}
};
@@ -126,29 +128,30 @@ struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTar
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for pk_int4_t, pk_int4_t, int32_t MMA operation on GFX11
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
: amdgcn_mma_base<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32(true, // A signedness
bit_cast<int32x2_t>(aVec),
true, // B signedness
bit_cast<int32x2_t>(bVec),
cVec,
CtrlFlags::Clamp)};
P::clamp)};
}
};

View File

@@ -17,6 +17,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp"
namespace ck_tile::core::arch::mma {
@@ -31,21 +32,21 @@ namespace ck_tile::core::arch::mma {
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
@@ -57,21 +58,21 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
@@ -83,21 +84,21 @@ struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp16_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
@@ -109,21 +110,21 @@ struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, bf16_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<bf16_t, bf16_t, bf16_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
@@ -135,30 +136,31 @@ struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // A signedness
bit_cast<int32x2_t>(aVec),
true, // B signedness
bit_cast<int32x2_t>(bVec),
cVec,
CtrlFlags::Clamp)};
P::clamp)};
}
};
@@ -166,21 +168,21 @@ struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTar
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
@@ -193,21 +195,21 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
@@ -220,21 +222,21 @@ struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
@@ -247,21 +249,21 @@ struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
@@ -274,30 +276,31 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for pk_int4_t, pk_int4_t, int32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12";
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12(true, // A signedness
bit_cast<int32_t>(aVec),
true, // B signedness
bit_cast<int32_t>(bVec),
cVec,
CtrlFlags::Clamp)};
P::clamp)};
}
};
@@ -305,30 +308,31 @@ struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 16u, CtrlFlags, Compi
* @struct amdgcn_mma
* @brief Specialization of amdgcn_wmma for pk_int4_t, pk_int4_t, int32_t MMA operation on GFX12
* architecture.
* @tparam CtrlFlags Control flags for the WMMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 template <amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
// clang-format on
{
static constexpr const char* instruction_name =
"__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12";
CK_TILE_DEVICE static auto
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
template <typename... Params>
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
{
using P = WarpGemmParamsParser<Params...>;
return {__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12(true, // A signedness
bit_cast<int32x2_t>(aVec),
true, // B signedness
bit_cast<int32x2_t>(bVec),
cVec,
CtrlFlags::Clamp)};
P::clamp)};
}
};

View File

@@ -52,7 +52,6 @@ struct MmaDefaultSelector<ADataType,
WaveTileM,
WaveTileN,
WaveTileK,
DefaultWmmaCtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE>::SelectedOp;
};

View File

@@ -4,8 +4,6 @@
#pragma once
#include "ck_tile/core/config.hpp"
#include <stdio.h>
#include <type_traits>
namespace ck_tile::core::arch::mma {
@@ -50,25 +48,4 @@ struct is_mma_op_wmma<MmaOp, std::enable_if_t<std::is_same_v<typename MmaOp::OpT
template <typename MmaOp>
static constexpr bool is_mma_op_wmma_v = is_mma_op_wmma<MmaOp>::value;
/**
* @struct DefaultWmmaCtrlFlags
* @brief Default WMMA control flags for dense and sparse WMMA operations.
*/
struct DefaultWmmaCtrlFlags
{
constexpr static bool Clamp = false;
// Only has an effect on gfx11 when the accumulator is 16-bit.
// Determines which half of the 32-bit accum register to use for the 16-bit result.
// false = low bits [15:0], true = high bits [31:16]
constexpr static bool UseHighAccumBits = true;
};
CK_TILE_HOST_DEVICE void print_flags(DefaultWmmaCtrlFlags const& ctrlFlags)
{
printf("CtrlFlags Clamp / UseHighAccumBits : %d / %d\n",
ctrlFlags.Clamp,
ctrlFlags.UseHighAccumBits);
}
} // namespace ck_tile::core::arch::mma

View File

@@ -9,21 +9,6 @@
namespace ck_tile::core::arch::mma {
/**
* @struct DuplicateTransform
* @brief Transform to duplicate low register elements to high register elements
*/
struct DuplicateTransform
{
template <typename VecType>
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
{
// TODO: Implement duplication logic to broadcast low
// register elements to high elements [0 - (N/2 -1)] -> [N/2 - (N-1)]
return std::forward<VecType>(v);
}
};
/**
* @struct PadTransform
* @brief Transform to pad data from original type to b32 type
@@ -59,8 +44,8 @@ struct UnpadTransform
*/
struct MmaDefaultTransformsGfx11
{
using ATransform = DuplicateTransform;
using BTransform = DuplicateTransform;
using ATransform = PassThroughTransform;
using BTransform = PassThroughTransform;
using CTransform = PadTransform;
using DTransform = UnpadTransform;
};

View File

@@ -83,6 +83,21 @@ struct SwapReuse_ : bool_constant<Value>
{
};
template <index_t Value>
struct Cbsz : number<Value>
{
};
template <index_t Value>
struct Abid : number<Value>
{
};
template <index_t Value>
struct Blgp : number<Value>
{
};
struct WarpGemmDefaultParams
{
using clamp = bool_constant<false>;
@@ -94,6 +109,9 @@ struct WarpGemmDefaultParams
using swap_reuse = bool_constant<false>; // internal use only
using scale_a = number<0>;
using scale_b = number<0>;
using cbsz = number<0>;
using abid = number<0>;
using blgp = number<0>;
};
template <typename T, template <index_t> class Tag>
@@ -151,6 +169,9 @@ class WarpGemmParamsParser
public:
static constexpr bool clamp = extract<Clamp, WarpGemmDefaultParams::clamp>();
static constexpr bool post_nop = extract<PostNop, WarpGemmDefaultParams::post_nop>();
static constexpr index_t cbsz = extract<Cbsz, WarpGemmDefaultParams::cbsz>();
static constexpr index_t abid = extract<Abid, WarpGemmDefaultParams::abid>();
static constexpr index_t blgp = extract<Blgp, WarpGemmDefaultParams::blgp>();
static constexpr bool reuse_a = swap_reuse ? raw_reuse_b : raw_reuse_a;
static constexpr bool reuse_b = swap_reuse ? raw_reuse_a : raw_reuse_b;
static constexpr index_t op_sel_a = swap_reuse ? raw_op_sel_b : raw_op_sel_a;

View File

@@ -87,6 +87,3 @@ if(GPU_TARGETS MATCHES "gfx120")
target_compile_options(test_amdgcn_mma_layout_gfx12 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
_add_mma_gtest(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp)
target_compile_options(test_amdgcn_mma_pipeline PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -1,66 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdint>
#include <gtest/gtest.h>
#include <iostream>
#include <numeric>
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
namespace {
using namespace ck_tile::core::arch::mma;
}
TEST(MmaPipelineOptionFlagsTests, ConversionTests)
{
MmaPipelineOptionFlags flags_0{};
MmaPipelineOptionFlags flags_1{MmaPipelineOptionFlag::ABSwap};
MmaPipelineOptionFlags flags_2{MmaPipelineOptionFlag::COMPRESS_A};
MmaPipelineOptionFlags flags_3{0b11}; // TODO c++20 - remove this
EXPECT_TRUE(flags_0.testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_FALSE(flags_0.testFlag(MmaPipelineOptionFlag::ABSwap));
EXPECT_FALSE(flags_0.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
EXPECT_TRUE(flags_1.testFlag(MmaPipelineOptionFlag::ABSwap));
EXPECT_FALSE(flags_1.testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_FALSE(flags_1.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
EXPECT_TRUE(flags_2.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
EXPECT_FALSE(flags_2.testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_FALSE(flags_2.testFlag(MmaPipelineOptionFlag::ABSwap));
EXPECT_TRUE(flags_3.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
EXPECT_TRUE(flags_3.testFlag(MmaPipelineOptionFlag::ABSwap));
EXPECT_FALSE(flags_3.testFlag(MmaPipelineOptionFlag::NONE));
}
TEST(MmaPipelineOptionFlagsTests, OperatorsTests)
{
MmaPipelineOptionFlags flags{};
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::NONE));
flags |= MmaPipelineOptionFlag::ABSwap;
EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::ABSwap));
flags |= MmaPipelineOptionFlag::COMPRESS_A;
EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::ABSwap));
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
flags &= MmaPipelineOptionFlag::COMPRESS_A;
EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::ABSwap));
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
EXPECT_FALSE((~flags).testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_TRUE((~flags).testFlag(MmaPipelineOptionFlag::ABSwap));
EXPECT_FALSE((~flags).testFlag(MmaPipelineOptionFlag::COMPRESS_A));
}

View File

@@ -40,7 +40,6 @@ void ScaleMfmaGfx950Specialization_impl()
WaveTileM,
WaveTileN,
WaveTileK,
DefaultScaleMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SCALE>;
@@ -79,10 +78,7 @@ TEST(ScaleMMATrait, ScaleMfmaGfx950Specialization)
std::cout << "GFX950 scale MFMA specialization is correct" << std::endl;
}
// TODO: It seems like the ExecSignature concept (and hence MmaOpI) can not be made to work for a
// templated device function for some reason. Disable test for now and fix this once we are using
// the variadic template pack for flags...
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER && 0
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
template <typename AType,
typename BType,
typename CType,
@@ -97,7 +93,6 @@ void TestConceptRequirements_impl()
WaveTileM,
WaveTileN,
WaveTileK,
DefaultScaleMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SCALE>;
@@ -107,7 +102,7 @@ void TestConceptRequirements_impl()
TEST(ScaleMMATrait, TestConceptRequirements)
{
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER && 0
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
TestConceptRequirements_impl<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u>();
TestConceptRequirements_impl<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u>();
TestConceptRequirements_impl<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u>();
@@ -216,9 +211,7 @@ struct ScalePipelineKernel
constexpr int32_t replicate_byte = 0x01010101;
ScaleAType scale_a = 126u * replicate_byte;
ScaleBType scale_b = 129u * replicate_byte;
static constexpr index_t opselA = 0;
static constexpr index_t opselB = 0;
Pipeline::template exec<opselA, opselB>(a, b, c, scale_a, scale_b);
Pipeline::template exec<OpSelA<0>, OpSelB<0>>(a, b, c, scale_a, scale_b);
__builtin_memcpy(
static_cast<uint8_t*>(c_per_lane) + lane * sizeof(CTensor), &c, sizeof(CTensor));
}
@@ -399,9 +392,7 @@ TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_32x32x64_Real)
// constexpr int32_t replicate_byte = 0x01010101;
// ScaleAType scale_a = 126u * replicate_byte;
// ScaleBType scale_b = 129u * replicate_byte;
// static constexpr index_t opselA = 0;
// static constexpr index_t opselB = 0;
// Pipeline::template exec<opselA, opselB>(a, b, c, scale_a, scale_b);
// Pipeline::template exec<OpSelA<0>, OpSelB<0>>(a, b, c, scale_a, scale_b);
// __builtin_memcpy(
// static_cast<uint8_t*>(c_per_lane) + lane * sizeof(CTensor), &c, sizeof(CTensor));
// }

View File

@@ -39,7 +39,6 @@ TEST(SparseMMATrait, SparseMfmaGfx950Specialization)
16u,
16u,
32u,
DefaultSparseMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;
@@ -60,7 +59,6 @@ TEST(SparseMMATrait, MmaOpTraitsIntegration)
16u,
16u,
32u,
DefaultSparseMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;
@@ -83,7 +81,6 @@ TEST(SparseMMATrait, TestConceptRequirements)
16u,
16u,
32u,
DefaultSparseMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;
EXPECT_TRUE(MmaOpI<TestSparseMmma>);
@@ -95,15 +92,8 @@ TEST(SparseMMATrait, TestConceptRequirements)
TEST(SparseMMATrait, DenseVsSparseDistinction)
{
// Dense MFMA from mfma/mfma_gfx9.hpp
using DenseMfma = amdgcn_mma<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
DefaultMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::DENSE>;
using DenseMfma =
amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CompilerTargetGfx950, MmaOpFamily::DENSE>;
// Sparse MFMA on GFX950
using SparseMfma = amdgcn_mma<fp16_t,
@@ -112,7 +102,6 @@ TEST(SparseMMATrait, DenseVsSparseDistinction)
16u,
16u,
32u,
DefaultSparseMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;

View File

@@ -27,9 +27,6 @@ using namespace ck_tile::core::arch::testing;
constexpr uint32_t DummyTargetIdVal = 55555u;
using DummyCompilerTarget = amdgcn_target<static_cast<amdgcn_target_id>(DummyTargetIdVal)>;
struct DummyOpType;
struct DummyCtrlFlags
{
};
/** @brief Returns true if the given target id matches the dummy */
constexpr bool is_dummy_target(DummyCompilerTarget dummy)
@@ -49,7 +46,7 @@ using enable_if_target_id_dummy_t = std::enable_if_t<is_dummy_target(CompilerTar
template <typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, DummyCtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_id_dummy_t<CompilerTarget>>
struct amdgcn_mma<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_id_dummy_t<CompilerTarget>>
: amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, 64u, 1, 1, 1, 1, 1, 1, 1, DummyOpType, MmaOpFamily::DENSE>
// clang-format on
{
@@ -63,15 +60,8 @@ struct amdgcn_mma<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, DummyCtrlFlags, CompilerTa
// Have an alias so we can test supported arch vs unsupported arch
// TODO: c++20 template <amdgcn_target_arch_id CompilerTarget>
template <typename CompilerTarget>
using DummyAmdgcnMma = amdgcn_mma<fp32_t,
fp32_t,
fp32_t,
8u,
8u,
8u,
DummyCtrlFlags,
CompilerTarget,
MmaOpFamily::DENSE>;
using DummyAmdgcnMma =
amdgcn_mma<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, CompilerTarget, MmaOpFamily::DENSE>;
/*! @struct MmaDefaultSelector
* @brief For dummy Id only, instantiate tests for both MFMA and WMMA selectors so we can them both

View File

@@ -23,6 +23,7 @@
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp"
#include <gtest/gtest.h>
#include <hip/hip_runtime.h>
@@ -256,12 +257,10 @@ struct MmaLayoutTestKernel
{
// The actual scale is computed as pow(2, scale - 127), so:
// 125 -> 2^-2 and 129 -> 2^2.
int scale_A = 125;
int scale_B = 129;
static constexpr index_t opselA = 0;
static constexpr index_t opselB = 0;
c_frag =
MmaOp::template exec<opselA, opselB>(a_frag, b_frag, c_frag, scale_A, scale_B);
int scale_A = 125;
int scale_B = 129;
c_frag = MmaOp::template exec<OpSelA<0>, OpSelB<0>>(
a_frag, b_frag, c_frag, scale_A, scale_B);
}
else
{
@@ -357,145 +356,145 @@ void run_mma_layout_test()
// available on all gfx9 (gfx908, gfx90a, gfx942, gfx950)
using Gfx9CommonIntrinsics = ::testing::Types<
amdgcn_mma<F32, F32, F32, 32u, 64u, 1u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x1f32
amdgcn_mma<F32, F32, F32, 64u, 32u, 1u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x1f32
amdgcn_mma<F32, F32, F32, 16u, 64u, 1u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x1f32
amdgcn_mma<F32, F32, F32, 64u, 16u, 1u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x1f32
amdgcn_mma<F32, F32, F32, 4u, 64u, 1u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x1f32
amdgcn_mma<F32, F32, F32, 64u, 4u, 1u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x1f32
amdgcn_mma<F32, F32, F32, 32u, 32u, 2u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x2f32
amdgcn_mma<F32, F32, F32, 16u, 16u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4f32
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
amdgcn_mma<F16, F16, F32, 16u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4f16
amdgcn_mma<F16, F16, F32, 64u, 16u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4f16
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 32u, 32u, 8u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x8f16
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
amdgcn_mma<I8, I8, I32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x4i8
amdgcn_mma<I8, I8, I32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x4i8
amdgcn_mma<I8, I8, I32, 16u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x4i8
amdgcn_mma<I8, I8, I32, 64u, 16u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x4i8
amdgcn_mma<I8, I8, I32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_4x4x4i8
amdgcn_mma<I8, I8, I32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE> // mfma_i32_4x4x4i8
amdgcn_mma<F32, F32, F32, 32u, 64u, 1u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x1f32
amdgcn_mma<F32, F32, F32, 64u, 32u, 1u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x1f32
amdgcn_mma<F32, F32, F32, 16u, 64u, 1u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x1f32
amdgcn_mma<F32, F32, F32, 64u, 16u, 1u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x1f32
amdgcn_mma<F32, F32, F32, 4u, 64u, 1u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x1f32
amdgcn_mma<F32, F32, F32, 64u, 4u, 1u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x1f32
amdgcn_mma<F32, F32, F32, 32u, 32u, 2u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x2f32
amdgcn_mma<F32, F32, F32, 16u, 16u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4f32
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
amdgcn_mma<F16, F16, F32, 16u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4f16
amdgcn_mma<F16, F16, F32, 64u, 16u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4f16
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 32u, 32u, 8u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x8f16
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
amdgcn_mma<I8, I8, I32, 32u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x4i8
amdgcn_mma<I8, I8, I32, 64u, 32u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x4i8
amdgcn_mma<I8, I8, I32, 16u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x4i8
amdgcn_mma<I8, I8, I32, 64u, 16u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x4i8
amdgcn_mma<I8, I8, I32, 4u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_4x4x4i8
amdgcn_mma<I8, I8, I32, 64u, 4u, 4u, TestTarget, MmaOpFamily::DENSE> // mfma_i32_4x4x4i8
>;
using Gfx908andGfx90aIntrinsics = ::testing::Types<
amdgcn_mma<BF16, BF16, F32, 32u, 64u, 2u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x2bf16
amdgcn_mma<BF16, BF16, F32, 64u, 32u, 2u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x2bf16
amdgcn_mma<BF16, BF16, F32, 16u, 64u, 2u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x2bf16
amdgcn_mma<BF16, BF16, F32, 64u, 16u, 2u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x2bf16
amdgcn_mma<BF16, BF16, F32, 4u, 64u, 2u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x2bf16
amdgcn_mma<BF16, BF16, F32, 64u, 4u, 2u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x2bf16
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4bf16
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 8u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x8bf16
amdgcn_mma<I8, I8, I32, 32u, 32u, 8u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x8i8
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE> // mfma_i32_16x16x16i8
amdgcn_mma<BF16, BF16, F32, 32u, 64u, 2u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x2bf16
amdgcn_mma<BF16, BF16, F32, 64u, 32u, 2u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x2bf16
amdgcn_mma<BF16, BF16, F32, 16u, 64u, 2u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x2bf16
amdgcn_mma<BF16, BF16, F32, 64u, 16u, 2u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x2bf16
amdgcn_mma<BF16, BF16, F32, 4u, 64u, 2u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x2bf16
amdgcn_mma<BF16, BF16, F32, 64u, 4u, 2u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x2bf16
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4bf16
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 8u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x8bf16
amdgcn_mma<I8, I8, I32, 32u, 32u, 8u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x8i8
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, TestTarget, MmaOpFamily::DENSE> // mfma_i32_16x16x16i8
>;
using Gfx90aAndHigherIntrinsics = ::testing::Types<
amdgcn_mma<BF16, BF16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4bf16_1k
amdgcn_mma<BF16, BF16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4bf16_1k
amdgcn_mma<BF16, BF16, F32, 16u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4bf16_1k
amdgcn_mma<BF16, BF16, F32, 64u, 16u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4bf16_1k
amdgcn_mma<BF16, BF16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4bf16_1k
amdgcn_mma<BF16, BF16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4bf16_1k
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 8u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x8bf16_1k
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x16bf16_1k
amdgcn_mma<F64, F64, F64, 16u, 16u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f64_16x16x4f64
amdgcn_mma<F64, F64, F64, 4u, 16u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f64_4x4x4f64
amdgcn_mma<F64, F64, F64, 16u, 4u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE> // mfma_f64_4x4x4f64
amdgcn_mma<BF16, BF16, F32, 32u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4bf16_1k
amdgcn_mma<BF16, BF16, F32, 64u, 32u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4bf16_1k
amdgcn_mma<BF16, BF16, F32, 16u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4bf16_1k
amdgcn_mma<BF16, BF16, F32, 64u, 16u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4bf16_1k
amdgcn_mma<BF16, BF16, F32, 4u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4bf16_1k
amdgcn_mma<BF16, BF16, F32, 64u, 4u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4bf16_1k
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 8u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x8bf16_1k
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x16bf16_1k
amdgcn_mma<F64, F64, F64, 16u, 16u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f64_16x16x4f64
amdgcn_mma<F64, F64, F64, 4u, 16u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f64_4x4x4f64
amdgcn_mma<F64, F64, F64, 16u, 4u, 4u, TestTarget, MmaOpFamily::DENSE> // mfma_f64_4x4x4f64
>;
using Gfx942AndHigherIntrinsics = ::testing::Types<
amdgcn_mma<I8, I8, I32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x32_i8
amdgcn_mma<I8, I8, I32, 32u, 32u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x16_i8
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_bf8_bf8
amdgcn_mma<BF8, F8, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_bf8_fp8
amdgcn_mma<F8, BF8, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_fp8_bf8
amdgcn_mma<F8, F8, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_fp8_fp8
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_bf8_bf8
amdgcn_mma<BF8, F8, F32, 32u, 32u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_bf8_fp8
amdgcn_mma<F8, BF8, F32, 32u, 32u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_fp8_bf8
amdgcn_mma<F8, F8, F32, 32u, 32u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_fp8_fp8
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x32_f16
amdgcn_mma<F16, F16, F32, 32u, 32u, 16u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x16_f16
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x32_bf16
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 16u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x16_bf16
amdgcn_mma<I8, I8, I32, 16u, 16u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_16x16x64_i8
amdgcn_mma<I8, I8, I32, 32u, 32u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_32x32x32_i8
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_bf8_bf8
amdgcn_mma<BF8, F8, F32, 16u, 16u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_bf8_fp8
amdgcn_mma<F8, BF8, F32, 16u, 16u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_fp8_bf8
amdgcn_mma<F8, F8, F32, 16u, 16u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_fp8_fp8
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_bf8_bf8
amdgcn_mma<BF8, F8, F32, 32u, 32u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_bf8_fp8
amdgcn_mma<F8, BF8, F32, 32u, 32u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_fp8_bf8
amdgcn_mma<F8, F8, F32, 32u, 32u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE> // smfmac_f32_32x32x32_fp8_fp8
amdgcn_mma<I8, I8, I32, 16u, 16u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x32_i8
amdgcn_mma<I8, I8, I32, 32u, 32u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x16_i8
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_bf8_bf8
amdgcn_mma<BF8, F8, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_bf8_fp8
amdgcn_mma<F8, BF8, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_fp8_bf8
amdgcn_mma<F8, F8, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_fp8_fp8
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_bf8_bf8
amdgcn_mma<BF8, F8, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_bf8_fp8
amdgcn_mma<F8, BF8, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_fp8_bf8
amdgcn_mma<F8, F8, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_fp8_fp8
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x32_f16
amdgcn_mma<F16, F16, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x16_f16
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x32_bf16
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x16_bf16
amdgcn_mma<I8, I8, I32, 16u, 16u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_16x16x64_i8
amdgcn_mma<I8, I8, I32, 32u, 32u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_32x32x32_i8
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_bf8_bf8
amdgcn_mma<BF8, F8, F32, 16u, 16u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_bf8_fp8
amdgcn_mma<F8, BF8, F32, 16u, 16u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_fp8_bf8
amdgcn_mma<F8, F8, F32, 16u, 16u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_fp8_fp8
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_bf8_bf8
amdgcn_mma<BF8, F8, F32, 32u, 32u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_bf8_fp8
amdgcn_mma<F8, BF8, F32, 32u, 32u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_fp8_bf8
amdgcn_mma<F8, F8, F32, 32u, 32u, 32u, TestTarget, MmaOpFamily::SPARSE> // smfmac_f32_32x32x32_fp8_fp8
>;
using Gfx942Intrinsics = ::testing::Types<
amdgcn_mma<TF32, TF32, F32, 16u, 16u, 8u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x8_xf32
amdgcn_mma<TF32, TF32, F32, 32u, 32u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE> // mfma_f32_32x32x4_xf32
amdgcn_mma<TF32, TF32, F32, 16u, 16u, 8u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x8_xf32
amdgcn_mma<TF32, TF32, F32, 32u, 32u, 4u, TestTarget, MmaOpFamily::DENSE> // mfma_f32_32x32x4_xf32
>;
using Gfx950Intrinsics = ::testing::Types<
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_bf16
amdgcn_mma<F16, F16, F32, 32u, 32u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_f16
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_bf16
amdgcn_mma<I8, I8, I32, 16u, 16u, 64u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x64_i8
amdgcn_mma<I8, I8, I32, 32u, 32u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x32_i8
amdgcn_mma<F8, F8, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
amdgcn_mma<F4, F4, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
amdgcn_mma<F6, F6, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
amdgcn_mma<BF6, BF6, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
amdgcn_mma<F8, F8, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
amdgcn_mma<F4, F4, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
amdgcn_mma<F6, F6, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
amdgcn_mma<BF6, BF6, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
amdgcn_mma<F16, F16, F32, 16u, 16u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_f16
amdgcn_mma<F16, F16, F32, 32u, 32u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_f16
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_bf16
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_bf16
amdgcn_mma<I8, I8, I32, 16u, 16u, 128u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_16x16x128_i8
amdgcn_mma<I8, I8, I32, 32u, 32u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_32x32x64_i8
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 128u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_bf8_bf8
amdgcn_mma<BF8, F8, F32, 16u, 16u, 128u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_bf8_fp8
amdgcn_mma<F8, BF8, F32, 16u, 16u, 128u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_fp8_bf8
amdgcn_mma<F8, F8, F32, 16u, 16u, 128u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_fp8_fp8
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x64_bf8_bf8
amdgcn_mma<BF8, F8, F32, 32u, 32u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x64_bf8_fp8
amdgcn_mma<F8, BF8, F32, 32u, 32u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x64_fp8_bf8
amdgcn_mma<F8, F8, F32, 32u, 32u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE> // smfmac_f32_32x32x64_fp8_fp8
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_bf16
amdgcn_mma<F16, F16, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_f16
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_bf16
amdgcn_mma<I8, I8, I32, 16u, 16u, 64u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x64_i8
amdgcn_mma<I8, I8, I32, 32u, 32u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x32_i8
amdgcn_mma<F8, F8, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
amdgcn_mma<F4, F4, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
amdgcn_mma<F6, F6, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
amdgcn_mma<BF6, BF6, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
amdgcn_mma<F8, F8, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
amdgcn_mma<F4, F4, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
amdgcn_mma<F6, F6, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
amdgcn_mma<BF6, BF6, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
amdgcn_mma<F16, F16, F32, 16u, 16u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_f16
amdgcn_mma<F16, F16, F32, 32u, 32u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_f16
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_bf16
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_bf16
amdgcn_mma<I8, I8, I32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_16x16x128_i8
amdgcn_mma<I8, I8, I32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_32x32x64_i8
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_bf8_bf8
amdgcn_mma<BF8, F8, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_bf8_fp8
amdgcn_mma<F8, BF8, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_fp8_bf8
amdgcn_mma<F8, F8, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_fp8_fp8
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x64_bf8_bf8
amdgcn_mma<BF8, F8, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x64_bf8_fp8
amdgcn_mma<F8, BF8, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x64_fp8_bf8
amdgcn_mma<F8, F8, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SPARSE> // smfmac_f32_32x32x64_fp8_fp8
>;
using Gfx11Intrinsics = ::testing::Types<
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf16_w32
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu8_w32
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE> // wmma_i32_16x16x16_iu4_w32
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf16_w32
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, Target11, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu8_w32
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, Target11, MmaOpFamily::DENSE> // wmma_i32_16x16x16_iu4_w32
>;
using Gfx12Intrinsics = ::testing::Types<
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf16_w32_gfx12
amdgcn_mma<F16, F16, F16, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f16_16x16x16_f16_w32_gfx12
amdgcn_mma<BF16, BF16, BF16, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_bf16_16x16x16_bf16_w32_gfx12
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu8_w32_gfx12
amdgcn_mma<F8, F8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_fp8_fp8_w32_gfx12
amdgcn_mma<F8, BF8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_fp8_bf8_w32_gfx12
amdgcn_mma<BF8, F8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf8_fp8_w32_gfx12
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf8_bf8_w32_gfx12
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu4_w32_gfx12
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x32_iu4_w32_gfx12
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_f16_w32
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf16_w32
amdgcn_mma<F16, F16, F16, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f16_16x16x32_f16_w32
amdgcn_mma<BF16, BF16, BF16, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_bf16_16x16x32_bf16_w32
amdgcn_mma<I8, I8, I32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_i32_16x16x32_iu8_w32
amdgcn_mma<F8, F8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_fp8_fp8_w32
amdgcn_mma<F8, BF8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_fp8_bf8_w32
amdgcn_mma<BF8, F8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf8_fp8_w32
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf8_bf8_w32
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_i32_16x16x32_iu4_w32
amdgcn_mma<I4, I4, I32, 16u, 16u, 64u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE> // swmmac_i32_16x16x64_iu4_w32
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf16_w32_gfx12
amdgcn_mma<F16, F16, F16, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f16_16x16x16_f16_w32_gfx12
amdgcn_mma<BF16, BF16, BF16, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_bf16_16x16x16_bf16_w32_gfx12
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu8_w32_gfx12
amdgcn_mma<F8, F8, F32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_fp8_fp8_w32_gfx12
amdgcn_mma<F8, BF8, F32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_fp8_bf8_w32_gfx12
amdgcn_mma<BF8, F8, F32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf8_fp8_w32_gfx12
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf8_bf8_w32_gfx12
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu4_w32_gfx12
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x32_iu4_w32_gfx12
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_f16_w32
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf16_w32
amdgcn_mma<F16, F16, F16, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_f16_16x16x32_f16_w32
amdgcn_mma<BF16, BF16, BF16, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_bf16_16x16x32_bf16_w32
amdgcn_mma<I8, I8, I32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_i32_16x16x32_iu8_w32
amdgcn_mma<F8, F8, F32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_fp8_fp8_w32
amdgcn_mma<F8, BF8, F32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_fp8_bf8_w32
amdgcn_mma<BF8, F8, F32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf8_fp8_w32
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf8_bf8_w32
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_i32_16x16x32_iu4_w32
amdgcn_mma<I4, I4, I32, 16u, 16u, 64u, Target12, MmaOpFamily::SPARSE> // swmmac_i32_16x16x64_iu4_w32
>;
// clang-format on