mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-28 18:56:59 +00:00
[rocm-libraries] ROCm/rocm-libraries#8227 (commit 75c30d5)
=?UTF-8?q?[CK=20TILE]=20Unification=20Work=20=E2=80=93=20?= =?UTF-8?q?Remove=20unification=20Flag=20structs=20in=20favor=20of=20new?= =?UTF-8?q?=20WarpGemmParams=20(#8227)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Recently, the way flags are sent down to the intrinsics was changed in CK Tile. At the point where the WarpGemm is invoked, an arbitrary number of template parameters can be passed, and these are passed down all the way to the lowest level intrinsics wrappers. Here `WarpGemmParamsParser<>` is used to extract flags for the intrinsics. In this MR we adapt the the unification framework (amdgcn_mma struct and MmaPipelines) to work in the same way. By doing this, there is no longer a point in our custom intrinsic Flag structs, so these are removed. Unrelated but I also tried removing the MmaPipeline flags because they arn't used for anything except CTranspose, which is already available. This also make test_amdgcn_mma_pipeline completely redundant so removed that as well.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
621697af8c
commit
2089713f94
@@ -73,17 +73,17 @@ int check_tile_distr_enc()
|
||||
// List of intrinsics to test.
|
||||
// clang-format off
|
||||
using Intrinsics = ck_tile::tuple<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, Target950, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu4_w32
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu4_w32_gfx12
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE> // wmma_i32_16x16x32_iu4_w32_gfx12
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, Target908, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, Target950, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, Target11, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu4_w32
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu4_w32_gfx12
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, Target12, MmaOpFamily::DENSE> // wmma_i32_16x16x32_iu4_w32_gfx12
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"
|
||||
|
||||
@@ -246,25 +246,14 @@ CK_TILE_HOST_DEVICE constexpr const char* to_string(Unsupported) { return "Unsup
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @concept HasExecSignature
|
||||
* @brief Helper concept for exec signature check.
|
||||
*/
|
||||
template <typename MmaOp, typename... ExecArgs>
|
||||
concept HasExecSignature = requires {
|
||||
{
|
||||
MmaOp::exec(typename MmaOp::AVecType{},
|
||||
typename MmaOp::BVecType{},
|
||||
typename MmaOp::CVecType{},
|
||||
std::declval<ExecArgs>()...)
|
||||
} -> std::convertible_to<typename MmaOp::CVecType>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @concept MmaOpI
|
||||
* @brief Expresses the meta-data interface required for each MmaOp policy.
|
||||
*/
|
||||
// TODO: Make sure this actually matches amdgcn_mma.
|
||||
// NOTE: It is no longer possible to perform a check on the exec() function, since it is now
|
||||
// templated over the variadic WarpGemmParams template pack for intrinsic flags. It seems like
|
||||
// concepts do not work for templated device functions.
|
||||
template <typename MmaOp>
|
||||
concept MmaOpI = requires(MmaOp op) {
|
||||
// Requires an op context
|
||||
@@ -287,7 +276,7 @@ concept MmaOpI = requires(MmaOp op) {
|
||||
{ MmaOp::kCMPerLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCMNumAccess } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCompressionRatio } -> std::convertible_to<unsigned int>;
|
||||
} && (HasExecSignature<MmaOp> || HasExecSignature<MmaOp, int> || HasExecSignature<MmaOp, int, int>);
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
@@ -305,7 +294,6 @@ concept MmaOpI = requires(MmaOp op) {
|
||||
* @tparam FragM M-dimension of mma intrinsic (MmaTile)
|
||||
* @tparam FragN N-dimension of mma intrinsic (MmaTile)
|
||||
* @tparam FragK K-dimension of mma intrinsic (MmaTile)
|
||||
* @tparam CtrlFlags Control flags for mma operation
|
||||
* @tparam CompilerTarget The current compiler target
|
||||
* @tparam OpFamily_ The type of operation (dense, sparse, scale, etc.)
|
||||
* @tparam Enabler SFINAE enabler
|
||||
@@ -316,7 +304,6 @@ template <typename ADataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CtrlFlags,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily_,
|
||||
typename Enabler = void>
|
||||
@@ -326,6 +313,7 @@ struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1
|
||||
// clang-format on
|
||||
{
|
||||
// This is a default pass-through implementation that doesn't do anything practical.
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
|
||||
{
|
||||
@@ -347,7 +335,6 @@ template <typename ADataType,
|
||||
std::uint32_t FragM,
|
||||
std::uint32_t FragN,
|
||||
std::uint32_t FragK,
|
||||
typename CtrlFlags,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily_,
|
||||
typename Enabler = void>
|
||||
@@ -357,7 +344,6 @@ CK_TILE_HOST_DEVICE void print(amdgcn_mma<ADataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
OpFamily_,
|
||||
Enabler> const& mmaObj)
|
||||
@@ -392,10 +378,6 @@ CK_TILE_HOST_DEVICE void print(amdgcn_mma<ADataType,
|
||||
printf(" kCNBlocks : %d\n", mmaObj.kCNBlocks);
|
||||
printf(" CBlockDimInVecDim : %d\n", mmaObj.CBlockDimInVecDim);
|
||||
printf("Instruction name : %s\n", ObjType::instruction_name);
|
||||
if constexpr(!std::is_same_v<CtrlFlags, void>)
|
||||
{
|
||||
print_flags(CtrlFlags{});
|
||||
}
|
||||
print(CompilerTarget{});
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -53,7 +53,6 @@ struct MmaDefaultSelector<ADataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
DefaultMfmaCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>::SelectedOp;
|
||||
};
|
||||
|
||||
@@ -3,16 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <stdio.h>
|
||||
#include <type_traits>
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
@@ -56,42 +48,4 @@ struct is_mma_op_mfma<MmaOp, std::enable_if_t<std::is_same_v<typename MmaOp::OpT
|
||||
template <typename MmaOp>
|
||||
static constexpr bool is_mma_op_mfma_v = is_mma_op_mfma<MmaOp>::value;
|
||||
|
||||
/**
|
||||
* @struct DefaultMfmaCtrlFlags
|
||||
* @brief Default MFMA flags, no broadcasting or rotation of inputs
|
||||
* @note For f64 MFMA instructions, CBSZ and ABID are ignored and BLGP is repurposed for matrix
|
||||
* negation. BLGP bits [0:2] negate the A, B, and C input matrices respectively (ref. ISA docs for
|
||||
* MI300 Instinct).
|
||||
*/
|
||||
struct DefaultMfmaCtrlFlags
|
||||
{
|
||||
static constexpr int32_t Cbsz = 0; // CBSZ flag, default 0
|
||||
static constexpr int32_t Abid = 0; // ABID flag, default 0
|
||||
static constexpr int32_t Blgp = 0; // BLGP flag, default 0
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE void print_flags(DefaultMfmaCtrlFlags const& ctrlFlags)
|
||||
{
|
||||
printf("CtrlFlags Cbsz / Abid / Blgp : %" PRId32 " / %" PRId32 " / %" PRId32 "\n",
|
||||
ctrlFlags.Cbsz,
|
||||
ctrlFlags.Abid,
|
||||
ctrlFlags.Blgp);
|
||||
}
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @concept CtrlFlagsGfx9I
|
||||
* @brief Expresses the interface of required members for each CtrlFlags type on Gfx9
|
||||
*/
|
||||
template <typename CtrlFlags>
|
||||
concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) {
|
||||
// Flag members for Gfx9 MFMA instructions
|
||||
{ CtrlFlags::Cbsz } -> std::convertible_to<int32_t>;
|
||||
{ CtrlFlags::Abid } -> std::convertible_to<int32_t>;
|
||||
{ CtrlFlags::Blgp } -> std::convertible_to<int32_t>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
@@ -16,88 +16,11 @@
|
||||
#endif
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/*! @enum MmaPipelineOptionFlag
|
||||
* @brief Individual option flags for configuring MmaPipeline behavior.
|
||||
*/
|
||||
enum struct MmaPipelineOptionFlag : unsigned
|
||||
{
|
||||
NONE = 0x0, ///< No flags set
|
||||
ABSwap = 0x1, ///< Swap A and B inputs to transpose the C output
|
||||
COMPRESS_A = 0x2, ///< Enable compressed (sparse) A matrix input
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaPipelineOptionFlags
|
||||
* @brief Type-safe bitmask wrapper for combining @ref MmaPipelineOptionFlag values.
|
||||
* @par Provides bitwise OR, AND, NOT, and equality operators for composing
|
||||
* and querying pipeline option flags.
|
||||
*/
|
||||
struct MmaPipelineOptionFlags
|
||||
{
|
||||
using Type = std::underlying_type_t<MmaPipelineOptionFlag>;
|
||||
|
||||
explicit constexpr MmaPipelineOptionFlags() : mFlags(0) {}
|
||||
explicit constexpr MmaPipelineOptionFlags(Type value) : mFlags(value) {}
|
||||
constexpr MmaPipelineOptionFlags(MmaPipelineOptionFlag singleFlag) : mFlags(toType(singleFlag))
|
||||
{
|
||||
}
|
||||
constexpr MmaPipelineOptionFlags(const MmaPipelineOptionFlags& original)
|
||||
: mFlags(original.mFlags)
|
||||
{
|
||||
}
|
||||
|
||||
constexpr MmaPipelineOptionFlags& operator|=(MmaPipelineOptionFlag addValue)
|
||||
{
|
||||
mFlags |= toType(addValue);
|
||||
return *this;
|
||||
}
|
||||
constexpr MmaPipelineOptionFlags operator|(MmaPipelineOptionFlag addValue) const
|
||||
{
|
||||
MmaPipelineOptionFlags result(*this);
|
||||
result |= addValue;
|
||||
return result;
|
||||
}
|
||||
constexpr MmaPipelineOptionFlags& operator&=(MmaPipelineOptionFlag maskValue)
|
||||
{
|
||||
mFlags &= toType(maskValue);
|
||||
return *this;
|
||||
}
|
||||
constexpr MmaPipelineOptionFlags operator&(MmaPipelineOptionFlag maskValue) const
|
||||
{
|
||||
MmaPipelineOptionFlags result(*this);
|
||||
result &= maskValue;
|
||||
return result;
|
||||
}
|
||||
constexpr MmaPipelineOptionFlags operator~() const
|
||||
{
|
||||
MmaPipelineOptionFlags result(*this);
|
||||
result.mFlags = ~result.mFlags;
|
||||
return result;
|
||||
}
|
||||
constexpr bool testFlag(MmaPipelineOptionFlag flag) const
|
||||
{
|
||||
return (flag == MmaPipelineOptionFlag::NONE) ? mFlags == toType(flag) : *this & flag;
|
||||
}
|
||||
constexpr operator bool() const { return mFlags != toType(MmaPipelineOptionFlag::NONE); }
|
||||
constexpr bool operator==(Type rhs) const { return mFlags == rhs; }
|
||||
|
||||
private:
|
||||
Type mFlags;
|
||||
static constexpr Type toType(MmaPipelineOptionFlag f) { return static_cast<Type>(f); }
|
||||
};
|
||||
|
||||
constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOptionFlags& rhs)
|
||||
{
|
||||
return rhs == lhs;
|
||||
}
|
||||
|
||||
/**
|
||||
* @class MmaPipelineBase
|
||||
* @brief CRTP base class that implements the common Mma pipeline logic shared by
|
||||
* all concrete pipeline drivers (e.g., dense wave-wise, sparse, etc.).
|
||||
*
|
||||
* @tparam Flags_ Compile-time bitmask of @ref MmaPipelineOptionFlag controlling
|
||||
* pipeline behavior (e.g., C transposition, A compression).
|
||||
* @tparam Derived The concrete CRTP-derived pipeline class. Must expose:
|
||||
* - Type aliases: @c AWarpTensor, @c BWarpTensor, @c CWarpTensor, @c MmaOp
|
||||
* - Transform aliases: @c ATransform, @c BTransform, @c CTransform, @c DTransform
|
||||
@@ -107,14 +30,11 @@ constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOpt
|
||||
* 1. Apply pre-transforms to input buffers (A, B, C).
|
||||
* 2. Delegate to @c Derived::execImpl for the actual mma loop.
|
||||
* 3. Apply post-transform to output buffer (D).
|
||||
* When @c ABSwap is set, the A and B inputs are swapped before step 1.
|
||||
* When CTranspose is used, the A and B inputs are swapped before step 1.
|
||||
*/
|
||||
// TODO: c++20: use MmaPipelineOptionFlags directly
|
||||
template <MmaPipelineOptionFlags::Type Flags_, typename Derived>
|
||||
template <typename Derived>
|
||||
struct MmaPipelineBase
|
||||
{
|
||||
static constexpr auto Flags = MmaPipelineOptionFlags(Flags_);
|
||||
|
||||
/**
|
||||
* @brief Entry point: execute the full Mma pipeline (transforms + mma loop + output).
|
||||
* @tparam ATensor Type of the A WaveTile tensor (static_distributed_tensor).
|
||||
@@ -125,17 +45,17 @@ struct MmaPipelineBase
|
||||
* @param accum Input/output accumulator WaveTile C.
|
||||
* @return The output WaveTile D after accumulation and post-transform.
|
||||
*/
|
||||
template <typename ATensor, typename BTensor, typename CTensor>
|
||||
template <typename... Params, typename ATensor, typename BTensor, typename CTensor>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(ATensor& a, BTensor& b, CTensor& accum)
|
||||
{
|
||||
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
|
||||
{
|
||||
if constexpr(Flags & MmaPipelineOptionFlag::ABSwap)
|
||||
if constexpr(Derived::CTranspose)
|
||||
{
|
||||
decltype(auto) a_transformed = Derived::ATransform::exec(b);
|
||||
decltype(auto) b_transformed = Derived::BTransform::exec(a);
|
||||
decltype(auto) c_transformed = Derived::CTransform::exec(accum);
|
||||
Derived::execImpl(a_transformed, b_transformed, c_transformed);
|
||||
Derived::template execImpl<Params...>(a_transformed, b_transformed, c_transformed);
|
||||
return Derived::DTransform::exec(c_transformed);
|
||||
}
|
||||
else
|
||||
@@ -143,7 +63,7 @@ struct MmaPipelineBase
|
||||
decltype(auto) a_transformed = Derived::ATransform::exec(a);
|
||||
decltype(auto) b_transformed = Derived::BTransform::exec(b);
|
||||
decltype(auto) c_transformed = Derived::CTransform::exec(accum);
|
||||
Derived::execImpl(a_transformed, b_transformed, c_transformed);
|
||||
Derived::template execImpl<Params...>(a_transformed, b_transformed, c_transformed);
|
||||
return Derived::DTransform::exec(c_transformed);
|
||||
}
|
||||
}
|
||||
@@ -153,7 +73,7 @@ struct MmaPipelineBase
|
||||
// Code should not reach here, but HOST/DEVICE compile passes are
|
||||
// weirdly intertwined and instead of having constexpr in the calling
|
||||
// site (tests) we do this. See also changes by this commit.
|
||||
return Derived::MmaOp::exec({}, {}, {});
|
||||
return Derived::MmaOp::template exec<Params...>({}, {}, {});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,11 +82,10 @@ struct MmaPipelineBase
|
||||
template <typename... Params, typename CTensor, typename ATensor, typename BTensor>
|
||||
CK_TILE_DEVICE void operator()(CTensor& c, ATensor& a, const BTensor& b) const
|
||||
{
|
||||
exec(a, b, c);
|
||||
exec<Params...>(a, b, c);
|
||||
}
|
||||
|
||||
template <index_t opselA,
|
||||
index_t opselB,
|
||||
template <typename... Params,
|
||||
typename ATensor,
|
||||
typename BTensor,
|
||||
typename CTensor,
|
||||
@@ -180,7 +99,7 @@ struct MmaPipelineBase
|
||||
|
||||
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
|
||||
{
|
||||
if constexpr(Flags & MmaPipelineOptionFlag::ABSwap)
|
||||
if constexpr(Derived::CTranspose)
|
||||
{
|
||||
// TODO: Figure out which combination of a/b, scale_A/B, and opselA/B needs to be
|
||||
// AB-swapped in order to get correct results. Note that WarpGemmParamsParser
|
||||
@@ -188,7 +107,7 @@ struct MmaPipelineBase
|
||||
decltype(auto) a_transformed = Derived::ATransform::exec(b);
|
||||
decltype(auto) b_transformed = Derived::BTransform::exec(a);
|
||||
decltype(auto) c_transformed = Derived::CTransform::exec(accum);
|
||||
Derived::template execImpl<opselA, opselB>(
|
||||
Derived::template execImpl<Params...>(
|
||||
a_transformed, b_transformed, c_transformed, scale_A, scale_B);
|
||||
return Derived::DTransform::exec(c_transformed);
|
||||
}
|
||||
@@ -197,7 +116,7 @@ struct MmaPipelineBase
|
||||
decltype(auto) a_transformed = Derived::ATransform::exec(a);
|
||||
decltype(auto) b_transformed = Derived::BTransform::exec(b);
|
||||
decltype(auto) c_transformed = Derived::CTransform::exec(accum);
|
||||
Derived::template execImpl<opselA, opselB>(
|
||||
Derived::template execImpl<Params...>(
|
||||
a_transformed, b_transformed, c_transformed, scale_A, scale_B);
|
||||
return Derived::DTransform::exec(c_transformed);
|
||||
}
|
||||
@@ -219,8 +138,7 @@ struct MmaPipelineBase
|
||||
const int32_t& a_scale,
|
||||
const int32_t& b_scale) const
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
exec<P::op_sel_a, P::op_sel_b>(a, b, c, a_scale, b_scale);
|
||||
exec<Params...>(a, b, c, a_scale, b_scale);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -232,8 +150,8 @@ struct MmaPipelineBase
|
||||
* @concept MmaPipelineI
|
||||
* @brief Expresses the meta-data interface required for a CRTP MmaPipeline.
|
||||
*/
|
||||
template <typename Derived, MmaPipelineOptionFlags::Type Flags>
|
||||
concept MmaPipelineInterface = std::derived_from<Derived, MmaPipelineBase<Flags, Derived>>;
|
||||
template <typename Derived>
|
||||
concept MmaPipelineInterface = std::derived_from<Derived, MmaPipelineBase<Derived>>;
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
|
||||
@@ -49,7 +49,6 @@ struct MmaDefaultSelector
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
void,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>;
|
||||
};
|
||||
@@ -88,7 +87,6 @@ template <typename ADataType,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileKTest,
|
||||
typename CtrlFlags,
|
||||
typename CompilerTarget, // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
MmaOpFamily OpFamily>
|
||||
struct MmaKSearchSelector
|
||||
@@ -102,7 +100,6 @@ struct MmaKSearchSelector
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
OpFamily>;
|
||||
|
||||
@@ -118,7 +115,6 @@ struct MmaKSearchSelector
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest / 2u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
OpFamily>::SelectedOp>;
|
||||
};
|
||||
@@ -128,7 +124,6 @@ template <typename ADataType,
|
||||
typename CDataType,
|
||||
uint32_t WaveTileM,
|
||||
uint32_t WaveTileN,
|
||||
typename CtrlFlags,
|
||||
typename CompilerTarget, // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
MmaOpFamily OpFamily>
|
||||
struct MmaKSearchSelector<ADataType,
|
||||
@@ -137,20 +132,12 @@ struct MmaKSearchSelector<ADataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
0u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
OpFamily>
|
||||
{
|
||||
// Recursion endpoint: unsupported default implementation.
|
||||
using SelectedOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
1u,
|
||||
1u,
|
||||
1u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
OpFamily>;
|
||||
using SelectedOp =
|
||||
amdgcn_mma<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget, OpFamily>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
@@ -6,11 +6,8 @@
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "mfma/mfma_traits.hpp"
|
||||
#include "scale/scale_traits.hpp"
|
||||
#include "sparse/sparse_traits.hpp"
|
||||
#include "wmma/wmma_traits.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
#include <stdio.h>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -61,7 +58,6 @@ struct MmaOpTraits;
|
||||
* @tparam FragM_ Size of the M dimension
|
||||
* @tparam FragN_ Size of the N dimension
|
||||
* @tparam FragK_ Size of the K dimension
|
||||
* @tparam CtrlFlags_ Control flags for the MMA operation
|
||||
* @tparam CompilerTarget_ The compiler target
|
||||
*/
|
||||
template <typename ADataType_,
|
||||
@@ -70,7 +66,6 @@ template <typename ADataType_,
|
||||
uint32_t FragM_,
|
||||
uint32_t FragN_,
|
||||
uint32_t FragK_,
|
||||
typename CtrlFlags_,
|
||||
typename CompilerTarget_,
|
||||
MmaOpFamily OpFamily_>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget_>
|
||||
@@ -80,7 +75,6 @@ struct MmaOpTraits<amdgcn_mma<ADataType_,
|
||||
FragM_,
|
||||
FragN_,
|
||||
FragK_,
|
||||
CtrlFlags_,
|
||||
CompilerTarget_,
|
||||
OpFamily_>>
|
||||
{
|
||||
@@ -90,12 +84,10 @@ struct MmaOpTraits<amdgcn_mma<ADataType_,
|
||||
FragM_,
|
||||
FragN_,
|
||||
FragK_,
|
||||
CtrlFlags_,
|
||||
CompilerTarget_,
|
||||
OpFamily_>;
|
||||
|
||||
// Capture incoming template parameters not already in amdgcn
|
||||
using CtrlFlags = CtrlFlags_;
|
||||
using CompilerTarget = CompilerTarget_;
|
||||
// TODO c++20static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_;
|
||||
|
||||
@@ -115,7 +107,6 @@ template <typename ADataType_,
|
||||
uint32_t FragM_,
|
||||
uint32_t FragN_,
|
||||
uint32_t FragK_,
|
||||
typename CtrlFlags_,
|
||||
typename CompilerTarget_,
|
||||
MmaOpFamily OpFamily_>
|
||||
CK_TILE_HOST_DEVICE void print(MmaOpTraits<amdgcn_mma<ADataType_,
|
||||
@@ -124,7 +115,6 @@ CK_TILE_HOST_DEVICE void print(MmaOpTraits<amdgcn_mma<ADataType_,
|
||||
FragM_,
|
||||
FragN_,
|
||||
FragK_,
|
||||
CtrlFlags_,
|
||||
CompilerTarget_,
|
||||
OpFamily_>> const& traitsObj)
|
||||
{
|
||||
@@ -134,7 +124,6 @@ CK_TILE_HOST_DEVICE void print(MmaOpTraits<amdgcn_mma<ADataType_,
|
||||
FragM_,
|
||||
FragN_,
|
||||
FragK_,
|
||||
CtrlFlags_,
|
||||
CompilerTarget_,
|
||||
OpFamily_>{});
|
||||
printf(
|
||||
|
||||
@@ -28,15 +28,6 @@ enum struct MmaAccumPolicy
|
||||
COL_MAJOR
|
||||
};
|
||||
|
||||
namespace dense::wavewise::detail {
|
||||
// TODO: c++20: return MmaPipelineOptionFlags directly
|
||||
template <bool SwapAB>
|
||||
constexpr inline int getPipelineFlags()
|
||||
{
|
||||
return static_cast<int>(SwapAB ? MmaPipelineOptionFlag::ABSwap : MmaPipelineOptionFlag::NONE);
|
||||
}
|
||||
} // namespace dense::wavewise::detail
|
||||
|
||||
/**
|
||||
* @class Mma
|
||||
* @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation
|
||||
@@ -50,7 +41,7 @@ constexpr inline int getPipelineFlags()
|
||||
* @tparam WaveTileN Mma WaveTile N dimension
|
||||
* @tparam WaveTileK Mma WaveTile K dimension
|
||||
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
|
||||
* @tparam CTranspose Swaps A and B input vectors and interprets C with transposed layout.
|
||||
* @tparam CTranspose_ Swaps A and B input vectors and interprets C with transposed layout.
|
||||
* @tparam SwizzleFactor SwizzleFactor for Tile Distribution Encoding calculation.
|
||||
* @tparam AttrNumAccessAV Extra unmerge factor for vector dimension for A vec, see amdgcn_mma.hpp.
|
||||
* @tparam AttrNumAccessBV Extra unmerge factor for vector dimension for B vec, see amdgcn_mma.hpp.
|
||||
@@ -72,7 +63,7 @@ template <typename ADataType_,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileK,
|
||||
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
|
||||
bool CTranspose = false,
|
||||
bool CTranspose_ = false,
|
||||
index_t SwizzleFactor = 1,
|
||||
index_t AttrNumAccessAV = 1,
|
||||
index_t AttrNumAccessBV = AttrNumAccessAV,
|
||||
@@ -92,11 +83,12 @@ template <typename ADataType_,
|
||||
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
|
||||
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
|
||||
// clang-format off
|
||||
struct WaveWiseMmaPipeline : public MmaPipelineBase<dense::wavewise::detail::getPipelineFlags<CTranspose>(), WaveWiseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>
|
||||
struct WaveWiseMmaPipeline : public MmaPipelineBase<WaveWiseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose_, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>
|
||||
{
|
||||
using Base = MmaPipelineBase<dense::wavewise::detail::getPipelineFlags<CTranspose>(), WaveWiseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>;
|
||||
using Base = MmaPipelineBase<WaveWiseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose_, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>;
|
||||
// clang-format on
|
||||
using MmaOp = MmaOp_;
|
||||
using MmaOp = MmaOp_;
|
||||
static constexpr bool CTranspose = CTranspose_;
|
||||
|
||||
using ADataType = typename MmaOp::ADataType;
|
||||
using BDataType = typename MmaOp::BDataType;
|
||||
@@ -185,7 +177,7 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase<dense::wavewise::detail::get
|
||||
static_assert(WaveTileK % MmaOp::kK == 0u, "WaveTileK must be a multiple of MmaOp::kK");
|
||||
|
||||
// TODO: Why does this even need to be a template? The types should be known.
|
||||
template <typename ATensor, typename BTensor, typename CTensor>
|
||||
template <typename... Params, typename ATensor, typename BTensor, typename CTensor>
|
||||
CK_TILE_DEVICE static void execImpl(ATensor& a, BTensor& b, CTensor& c)
|
||||
{
|
||||
static_assert(
|
||||
@@ -205,9 +197,10 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase<dense::wavewise::detail::get
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < FragsK; ++bk)
|
||||
{
|
||||
c_buf.at(bm * FragsN + bn) = MmaOp::exec(a_buf.at(bm * FragsK + bk),
|
||||
b_buf.at(bn * FragsK + bk),
|
||||
c_buf.at(bm * FragsN + bn));
|
||||
c_buf.at(bm * FragsN + bn) =
|
||||
MmaOp::template exec<Params...>(a_buf.at(bm * FragsK + bk),
|
||||
b_buf.at(bn * FragsK + bk),
|
||||
c_buf.at(bm * FragsN + bn));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -220,9 +213,10 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase<dense::wavewise::detail::get
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < FragsK; ++bk)
|
||||
{
|
||||
c_buf.at(bm * FragsN + bn) = MmaOp::exec(a_buf.at(bm * FragsK + bk),
|
||||
b_buf.at(bn * FragsK + bk),
|
||||
c_buf.at(bm * FragsN + bn));
|
||||
c_buf.at(bm * FragsN + bn) =
|
||||
MmaOp::template exec<Params...>(a_buf.at(bm * FragsK + bk),
|
||||
b_buf.at(bn * FragsK + bk),
|
||||
c_buf.at(bm * FragsN + bn));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
@@ -23,14 +24,13 @@ namespace ck_tile::core::arch::mma {
|
||||
* This specialization implements the Scale MFMA instruction for fp8_t A and B
|
||||
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 2, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
// clang-format on
|
||||
@@ -38,19 +38,20 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4";
|
||||
|
||||
template <index_t opselA, index_t opselB>
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
|
||||
opselA,
|
||||
P::op_sel_a,
|
||||
scale_A,
|
||||
opselB,
|
||||
P::op_sel_b,
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
@@ -62,14 +63,13 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
|
||||
* This specialization implements the Scale MFMA instruction for bf8_t A and B
|
||||
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 2, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
// clang-format on
|
||||
@@ -77,19 +77,20 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4";
|
||||
|
||||
template <index_t opselA, index_t opselB>
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
|
||||
opselA,
|
||||
P::op_sel_a,
|
||||
scale_A,
|
||||
opselB,
|
||||
P::op_sel_b,
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
@@ -101,14 +102,13 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
|
||||
* This specialization implements the Scale MFMA instruction for pk_fp4_t A and B
|
||||
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
// clang-format on
|
||||
@@ -116,10 +116,11 @@ struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, CtrlFlags, Compile
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4";
|
||||
|
||||
template <index_t opselA, index_t opselB>
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
int32x4_t arg_a = bit_cast<int32x4_t>(aVec);
|
||||
int32x4_t arg_b = bit_cast<int32x4_t>(bVec);
|
||||
|
||||
@@ -129,9 +130,9 @@ struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, CtrlFlags, Compile
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
|
||||
opselA,
|
||||
P::op_sel_a,
|
||||
scale_A,
|
||||
opselB,
|
||||
P::op_sel_b,
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
@@ -143,33 +144,33 @@ struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, CtrlFlags, Compile
|
||||
* This specialization implements the Scale MFMA instruction for pk_fp6x16_t A and B
|
||||
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
struct amdgcn_mma<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4";
|
||||
|
||||
template <index_t opselA, index_t opselB>
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0},
|
||||
int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0},
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp6x16_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp6x16_t>,
|
||||
opselA,
|
||||
P::op_sel_a,
|
||||
scale_A,
|
||||
opselB,
|
||||
P::op_sel_b,
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
@@ -182,33 +183,33 @@ struct amdgcn_mma<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 16u, 16u, 128u, CtrlFlags, C
|
||||
* This specialization implements the Scale MFMA instruction for pk_bf6x16_t A and B
|
||||
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
struct amdgcn_mma<pk_bf6x16_t, pk_bf6x16_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<pk_bf6x16_t, pk_bf6x16_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<pk_bf6x16_t, pk_bf6x16_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4";
|
||||
|
||||
template <index_t opselA, index_t opselB>
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0},
|
||||
int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0},
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_bf6x16_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_bf6x16_t>,
|
||||
opselA,
|
||||
P::op_sel_a,
|
||||
scale_A,
|
||||
opselB,
|
||||
P::op_sel_b,
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
@@ -221,14 +222,13 @@ struct amdgcn_mma<pk_bf6x16_t, pk_bf6x16_t, fp32_t, 16u, 16u, 128u, CtrlFlags, C
|
||||
* This specialization implements the Scale MFMA instruction for fp8_t A and B
|
||||
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 2, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
// clang-format on
|
||||
@@ -236,19 +236,20 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4";
|
||||
|
||||
template <index_t opselA, index_t opselB>
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
|
||||
opselA,
|
||||
P::op_sel_a,
|
||||
scale_A,
|
||||
opselB,
|
||||
P::op_sel_b,
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
@@ -260,14 +261,13 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget
|
||||
* This specialization implements the Scale MFMA instruction for bf8_t A and B
|
||||
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 2, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
// clang-format on
|
||||
@@ -275,19 +275,20 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4";
|
||||
|
||||
template <index_t opselA, index_t opselB>
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
|
||||
opselA,
|
||||
P::op_sel_a,
|
||||
scale_A,
|
||||
opselB,
|
||||
P::op_sel_b,
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
@@ -299,14 +300,13 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget
|
||||
* This specialization implements the Scale MFMA instruction for pk_fp4_t A and B
|
||||
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
// clang-format on
|
||||
@@ -314,10 +314,11 @@ struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, CtrlFlags, Compiler
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4";
|
||||
|
||||
template <index_t opselA, index_t opselB>
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
int32x4_t arg_a = bit_cast<int32x4_t>(aVec);
|
||||
int32x4_t arg_b = bit_cast<int32x4_t>(bVec);
|
||||
|
||||
@@ -327,9 +328,9 @@ struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, CtrlFlags, Compiler
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
|
||||
opselA,
|
||||
P::op_sel_a,
|
||||
scale_A,
|
||||
opselB,
|
||||
P::op_sel_b,
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
@@ -341,33 +342,33 @@ struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, CtrlFlags, Compiler
|
||||
* This specialization implements the Scale MFMA instruction for pk_fp6x16_t A and B
|
||||
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
struct amdgcn_mma<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4";
|
||||
|
||||
template <index_t opselA, index_t opselB>
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0},
|
||||
int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0},
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp6x16_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp6x16_t>,
|
||||
opselA,
|
||||
P::op_sel_a,
|
||||
scale_A,
|
||||
opselB,
|
||||
P::op_sel_b,
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
@@ -380,33 +381,33 @@ struct amdgcn_mma<pk_fp6x16_t, pk_fp6x16_t, fp32_t, 32u, 32u, 64u, CtrlFlags, Co
|
||||
* This specialization implements the Scale MFMA instruction for pk_bf6x16_t A and B
|
||||
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
struct amdgcn_mma<pk_bf6x16_t, pk_bf6x16_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<pk_bf6x16_t, pk_bf6x16_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<pk_bf6x16_t, pk_bf6x16_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4";
|
||||
|
||||
template <index_t opselA, index_t opselB>
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0},
|
||||
int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0},
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_bf6x16_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_bf6x16_t>,
|
||||
opselA,
|
||||
P::op_sel_a,
|
||||
scale_A,
|
||||
opselB,
|
||||
P::op_sel_b,
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -55,7 +55,6 @@ struct MmaDefaultSelector<ADataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
DefaultScaleMfmaCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SCALE>;
|
||||
};
|
||||
|
||||
@@ -32,7 +32,7 @@ namespace ck_tile::core::arch::mma {
|
||||
* @tparam WaveTileN Mma WaveTile N dimension
|
||||
* @tparam WaveTileK Mma WaveTile K dimension
|
||||
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
|
||||
* @tparam CTranspose Swaps A and B input vectors and interprets C with transposed layout.
|
||||
* @tparam CTranspose_ Swaps A and B input vectors and interprets C with transposed layout.
|
||||
* @tparam SwizzleFactor Swizzlefactor for Tile Distribution Encoding calculation.
|
||||
* @tparam AttrNumAccessAV Extra unmerge factor for vector dimension for A vec, see amdgcn_mma.hpp.
|
||||
* @tparam AttrNumAccessBV Extra unmerge factor for vector dimension for B vec, see amdgcn_mma.hpp.
|
||||
@@ -47,7 +47,7 @@ template <typename ADataType_,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileK,
|
||||
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
|
||||
bool CTranspose = false,
|
||||
bool CTranspose_ = false,
|
||||
index_t SwizzleFactor = 1,
|
||||
index_t AttrNumAccessAV = 1,
|
||||
index_t AttrNumAccessBV = AttrNumAccessAV,
|
||||
@@ -67,12 +67,13 @@ template <typename ADataType_,
|
||||
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
|
||||
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
|
||||
// clang-format off
|
||||
struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>
|
||||
struct ScaleMmaPipeline : public MmaPipelineBase<ScaleMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose_, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>
|
||||
{
|
||||
using Base = MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>;
|
||||
using Base = MmaPipelineBase<ScaleMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose_, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>;
|
||||
// clang-format on
|
||||
|
||||
using MmaOp = MmaOp_; // Expose the selected MmaOp
|
||||
using MmaOp = MmaOp_; // Expose the selected MmaOp
|
||||
static constexpr bool CTranspose = CTranspose_;
|
||||
|
||||
using ADataType = typename MmaOp::ADataType;
|
||||
using BDataType = typename MmaOp::BDataType;
|
||||
@@ -170,8 +171,7 @@ struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOpt
|
||||
static_assert(WaveTileK % MmaOp::kK == 0u, "WaveTileK must be a multiple of MmaOp::kK");
|
||||
|
||||
// TODO: Why does this even need to be a template? The types should be known.
|
||||
template <index_t opselA,
|
||||
index_t opselB,
|
||||
template <typename... Params,
|
||||
typename ATensor,
|
||||
typename BTensor,
|
||||
typename CTensor,
|
||||
@@ -198,11 +198,11 @@ struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOpt
|
||||
for(uint32_t bk = 0u; bk < FragsK; ++bk)
|
||||
{
|
||||
c_buf.at(bm * FragsN + bn) =
|
||||
MmaOp::template exec<opselA, opselB>(a_buf.at(bm * FragsK + bk),
|
||||
b_buf.at(bn * FragsK + bk),
|
||||
c_buf.at(bm * FragsN + bn),
|
||||
scale_A,
|
||||
scale_B);
|
||||
MmaOp::template exec<Params...>(a_buf.at(bm * FragsK + bk),
|
||||
b_buf.at(bn * FragsK + bk),
|
||||
c_buf.at(bm * FragsN + bn),
|
||||
scale_A,
|
||||
scale_B);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -216,11 +216,11 @@ struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOpt
|
||||
for(uint32_t bk = 0u; bk < FragsK; ++bk)
|
||||
{
|
||||
c_buf.at(bm * FragsN + bn) =
|
||||
MmaOp::template exec<opselA, opselB>(a_buf.at(bm * FragsK + bk),
|
||||
b_buf.at(bn * FragsK + bk),
|
||||
c_buf.at(bm * FragsN + bn),
|
||||
scale_A,
|
||||
scale_B);
|
||||
MmaOp::template exec<Params...>(a_buf.at(bm * FragsK + bk),
|
||||
b_buf.at(bn * FragsK + bk),
|
||||
c_buf.at(bm * FragsN + bn),
|
||||
scale_A,
|
||||
scale_B);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,85 +3,32 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/numeric/pk_f6.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
#include <stdio.h>
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
namespace scale::detail {
|
||||
|
||||
// Utility for converting the datatype of the A or B input matrix in a scale intrinsics to the
|
||||
// appropriate datatype flag. Note that this is not the same as the flag indicating the scale
|
||||
// datatype, see ScaleDataTypeToEnum.
|
||||
template <typename T>
|
||||
struct ScaleDataTypeToFlag;
|
||||
|
||||
inline constexpr int32_t ScaleDataTypeToFlag_v = [] {
|
||||
// sizeof(T) trick to only trigger the static assert for unsupported datatypes.
|
||||
static_assert(sizeof(T) == 0, "Unsupported scale data type");
|
||||
return -1;
|
||||
}();
|
||||
template <>
|
||||
struct ScaleDataTypeToFlag<fp8_t> // e4m3 (4 exponent bits 3 mantissa bits)
|
||||
{
|
||||
static constexpr int32_t value = 0;
|
||||
};
|
||||
|
||||
inline constexpr int32_t ScaleDataTypeToFlag_v<fp8_t> = 0; // e4m3
|
||||
template <>
|
||||
struct ScaleDataTypeToFlag<bf8_t> // e5m2
|
||||
{
|
||||
static constexpr int32_t value = 1;
|
||||
};
|
||||
|
||||
inline constexpr int32_t ScaleDataTypeToFlag_v<bf8_t> = 1; // e5m2
|
||||
template <>
|
||||
struct ScaleDataTypeToFlag<pk_fp6x16_t> // e2m3
|
||||
{
|
||||
static constexpr int32_t value = 2;
|
||||
};
|
||||
|
||||
inline constexpr int32_t ScaleDataTypeToFlag_v<pk_fp6x16_t> = 2; // e2m3
|
||||
template <>
|
||||
struct ScaleDataTypeToFlag<pk_bf6x16_t> // e3m2
|
||||
{
|
||||
static constexpr int32_t value = 3;
|
||||
};
|
||||
|
||||
inline constexpr int32_t ScaleDataTypeToFlag_v<pk_bf6x16_t> = 3; // e3m2
|
||||
template <>
|
||||
struct ScaleDataTypeToFlag<pk_fp4_t> // e2m1
|
||||
{
|
||||
static constexpr int32_t value = 4;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline constexpr int32_t ScaleDataTypeToFlag_v = ScaleDataTypeToFlag<T>::value;
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @concept ScaleMfmaDataTypeToFlag
|
||||
* @brief Expresses the interface of required members for each DataTypeToFlag type on Gfx9
|
||||
*/
|
||||
template <typename DataTypeToFlag>
|
||||
concept ScaleMfmaDataTypeToFlag = requires(DataTypeToFlag dataTypeToFlag) {
|
||||
// Flag members for scale MFMA instructions
|
||||
{ DataTypeToFlag::value } -> std::convertible_to<int32_t>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
inline constexpr int32_t ScaleDataTypeToFlag_v<pk_fp4_t> = 4; // e2m1
|
||||
|
||||
} // namespace scale::detail
|
||||
|
||||
// No real flags for now, scale and opsel are handled in higher level and passed down directly.
|
||||
// OPSEL is now passed as a template arg to exec(), see mma_pipeline.hpp
|
||||
// We will soon get rid of these flags entirely in favor of variadic template packs passed down to
|
||||
// the intrinsics directly, see WarpGemmParamsParser<>.
|
||||
struct DefaultScaleMfmaCtrlFlags
|
||||
{
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE void print_flags([[maybe_unused]] DefaultScaleMfmaCtrlFlags const& ctrlFlags)
|
||||
{
|
||||
printf("CtrlFlags: (empty)\n");
|
||||
}
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
@@ -55,7 +54,6 @@ struct MmaDefaultSelector<ADataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
DefaultSparseMfmaCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE>::SelectedOp;
|
||||
};
|
||||
|
||||
@@ -7,11 +7,11 @@
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
@@ -24,55 +24,55 @@ namespace ck_tile::core::arch::mma {
|
||||
* This specialization implements the SMFMA instruction for fp16_t A and B
|
||||
* matrices with structured sparsity, fp32_t accumulator, with 16x16x32 fragment sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Sparse MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsSparseMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950)>>
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 64u, 8, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x32_f16";
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x32_f16(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
}
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return __builtin_amdgcn_smfmac_f32_16x16x32_f16(
|
||||
aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
idx,
|
||||
P::cbsz, // Ignore abid and use first portion Y/N
|
||||
P::abid); // Portion of idx VGPR containing idx info
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX942 and GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 32u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 32u, 16u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 32u, 32u, 16u, 64u, 8, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x16_f16";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x16_f16(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x16_f16(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -80,27 +80,24 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 32u, 16u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX942 and GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, 64u, 8, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x32_bf16";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x32_bf16(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x32_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -108,27 +105,24 @@ struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX942 and GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 32u, 32u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 32u, 32u, 16u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 32u, 32u, 16u, 64u, 8, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x16_bf16";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x16_bf16(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x16_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -136,27 +130,24 @@ struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 32u, 32u, 16u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX942 and
|
||||
* GFX950 architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 64u, 64u, 16, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_16x16x64_i8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_i32_16x16x64_i8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_smfmac_i32_16x16x64_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -164,27 +155,24 @@ struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 64u, CtrlFlags, CompilerTar
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX942 and
|
||||
* GFX950 architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 32u, 32u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<int8_t, int8_t, int32_t, 32u, 32u, 32u, 64u, 16, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_32x32x32_i8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_i32_32x32x32_i8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_smfmac_i32_32x32x32_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -192,27 +180,25 @@ struct amdgcn_mma<int8_t, int8_t, int32_t, 32u, 32u, 32u, CtrlFlags, CompilerTar
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX942 and GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 64u, 64u, 16, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_bf8_bf8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x64_bf8_bf8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_16x16x64_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -220,27 +206,25 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX942 and GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 16u, 16u, 64u, 64u, 16, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_bf8_fp8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x64_bf8_fp8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_16x16x64_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -248,27 +232,25 @@ struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX942 and GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 16u, 16u, 64u, 64u, 16, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_fp8_bf8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x64_fp8_bf8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_16x16x64_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -276,27 +258,25 @@ struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX942 and GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 64u, 64u, 16, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_fp8_fp8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x64_fp8_fp8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_16x16x64_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -304,27 +284,25 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX942 and GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 32u, 32u, 32u, 64u, 16, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_bf8_bf8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x32_bf8_bf8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_32x32x32_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -332,27 +310,25 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX942 and GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 32u, 32u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 32u, 32u, 32u, 64u, 16, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_bf8_fp8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x32_bf8_fp8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_32x32x32_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -360,27 +336,25 @@ struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX942 and GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 32u, 32u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 32u, 32u, 32u, 64u, 16, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_fp8_bf8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x32_fp8_bf8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_32x32x32_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -388,27 +362,25 @@ struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX942 and GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsCdna3I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 32u, 32u, 32u, 64u, 16, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_fp8_fp8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x32_fp8_fp8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_32x32x32_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -416,27 +388,24 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 64u, 64u, 16, 1, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_f16";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x64_f16(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x64_f16(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -444,27 +413,24 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 32u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 32u, 32u, 32u, 64u, 16, 1, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_f16";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x32_f16(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x32_f16(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -472,27 +438,24 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 64u, 64u, 16, 1, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_bf16";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x64_bf16(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x64_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -500,27 +463,24 @@ struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 32u, 32u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 32u, 32u, 32u, 64u, 16, 1, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_bf16";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x32_bf16(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x32_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -528,27 +488,24 @@ struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 32u, 32u, 32u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_16x16x128_i8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_i32_16x16x128_i8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_smfmac_i32_16x16x128_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -556,27 +513,24 @@ struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 128u, CtrlFlags, CompilerTa
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<int8_t, int8_t, int32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_32x32x64_i8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_i32_32x32x64_i8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_smfmac_i32_32x32x64_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -584,27 +538,25 @@ struct amdgcn_mma<int8_t, int8_t, int32_t, 32u, 32u, 64u, CtrlFlags, CompilerTar
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_bf8_bf8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x128_bf8_bf8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_16x16x128_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -612,27 +564,25 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_bf8_fp8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x128_bf8_fp8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_16x16x128_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -640,27 +590,25 @@ struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_fp8_bf8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x128_fp8_bf8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_16x16x128_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -668,27 +616,25 @@ struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_fp8_fp8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x128_fp8_fp8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_16x16x128_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -696,27 +642,25 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarge
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_bf8_bf8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x64_bf8_bf8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_32x32x64_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -724,27 +668,25 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_bf8_fp8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x64_bf8_fp8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_32x32x64_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -752,27 +694,25 @@ struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_fp8_bf8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x64_fp8_bf8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_32x32x64_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -780,27 +720,25 @@ struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX950
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx950I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_fp8_fp8";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_32x32x64_fp8_fp8(
|
||||
aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {
|
||||
__builtin_amdgcn_smfmac_f32_32x32x64_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)};
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
@@ -11,5 +11,4 @@ namespace ck_tile::core::arch::mma {
|
||||
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"
|
||||
|
||||
@@ -12,16 +12,6 @@
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
namespace sparse::detail {
|
||||
// TODO: c++20: return MmaPipelineOptionFlags directly
|
||||
template <bool SwapAB>
|
||||
constexpr inline int getPipelineFlags()
|
||||
{
|
||||
return static_cast<int>(MmaPipelineOptionFlag::COMPRESS_A) |
|
||||
static_cast<int>(SwapAB ? MmaPipelineOptionFlag::ABSwap : MmaPipelineOptionFlag::NONE);
|
||||
}
|
||||
} // namespace sparse::detail
|
||||
|
||||
/**
|
||||
* @class SparseMmaPipeline
|
||||
* @brief Driver for the wave-tile sparse Mma operation. Given a backend MmaOp implementation
|
||||
@@ -38,7 +28,7 @@ constexpr inline int getPipelineFlags()
|
||||
* @tparam WaveTileN Mma WaveTile N dimension
|
||||
* @tparam WaveTileK Mma WaveTile K dimension
|
||||
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
|
||||
* @tparam CTranspose Swaps A and B input vectors and interprets C with transposed layout.
|
||||
* @tparam CTranspose_ Swaps A and B input vectors and interprets C with transposed layout.
|
||||
* @tparam SwizzleFactor SwizzleFactor for Tile Distribution Encoding calculation.
|
||||
* @tparam AttrNumAccessAV Extra unmerge factor for vector dimension for A vec, see amdgcn_mma.hpp.
|
||||
* @tparam AttrNumAccessBV Extra unmerge factor for vector dimension for B vec, see amdgcn_mma.hpp.
|
||||
@@ -53,7 +43,7 @@ template <typename ADataType_,
|
||||
uint32_t WaveTileN,
|
||||
uint32_t WaveTileK,
|
||||
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
|
||||
bool CTranspose = false,
|
||||
bool CTranspose_ = false,
|
||||
index_t SwizzleFactor = 1,
|
||||
index_t AttrNumAccessAV = 1,
|
||||
index_t AttrNumAccessBV = AttrNumAccessAV,
|
||||
@@ -73,11 +63,12 @@ template <typename ADataType_,
|
||||
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
|
||||
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
|
||||
// clang-format off
|
||||
struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFlags<CTranspose>(), SparseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>
|
||||
struct SparseMmaPipeline : public MmaPipelineBase<SparseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose_, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>
|
||||
{
|
||||
using Base = MmaPipelineBase<sparse::detail::getPipelineFlags<CTranspose>(), SparseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>;
|
||||
using Base = MmaPipelineBase<SparseMmaPipeline<ADataType_, BDataType_, CDataType_, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CTranspose_, SwizzleFactor, AttrNumAccessAV, AttrNumAccessBV, CompilerTarget, MmaOp_, MmaTransforms>>;
|
||||
// clang-format on
|
||||
using MmaOp = MmaOp_;
|
||||
using MmaOp = MmaOp_;
|
||||
static constexpr bool CTranspose = CTranspose_;
|
||||
|
||||
using ADataType = typename MmaOp::ADataType;
|
||||
using BDataType = typename MmaOp::BDataType;
|
||||
@@ -86,8 +77,7 @@ struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFla
|
||||
static_assert(!MmaOpTraits<MmaOp>::IsSupported || std::is_same_v<ADataType, ADataType_>);
|
||||
static_assert(!MmaOpTraits<MmaOp>::IsSupported || std::is_same_v<BDataType, BDataType_>);
|
||||
static_assert(!MmaOpTraits<MmaOp>::IsSupported || std::is_same_v<CDataType, CDataType_>);
|
||||
static_assert(!(Base::Flags & MmaPipelineOptionFlag::ABSwap),
|
||||
"Cannot transpose C in sparse intrinsics.");
|
||||
static_assert(!CTranspose, "Cannot transpose C in sparse intrinsics.");
|
||||
|
||||
// WaveTile dimensions (Used to be fragment dims but higher level expects these to include k
|
||||
// iteration!)
|
||||
@@ -180,7 +170,7 @@ struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFla
|
||||
|
||||
// ATransformResult is a big ext_vector plus idx, B and C are static_distributed tensors. Fix
|
||||
// later TODO.
|
||||
template <typename ATransformResult, typename BTensor, typename CTensor>
|
||||
template <typename... Params, typename ATransformResult, typename BTensor, typename CTensor>
|
||||
CK_TILE_DEVICE static void execImpl(ATransformResult& a, BTensor& b, CTensor& c)
|
||||
{
|
||||
static_assert(
|
||||
@@ -206,7 +196,7 @@ struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFla
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < FragsK; ++bk)
|
||||
{
|
||||
c_buf.at(bm * FragsN + bn) = MmaOp::exec(
|
||||
c_buf.at(bm * FragsN + bn) = MmaOp::template exec<Params...>(
|
||||
a_frags[bm][bk],
|
||||
b_buf.at(bn * FragsK + bk),
|
||||
c_buf.at(bm * FragsN + bn),
|
||||
@@ -224,7 +214,7 @@ struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFla
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < FragsK; ++bk)
|
||||
{
|
||||
c_buf.at(bm * FragsN + bn) = MmaOp::exec(
|
||||
c_buf.at(bm * FragsN + bn) = MmaOp::template exec<Params...>(
|
||||
a_frags[bm][bk],
|
||||
b_buf.at(bn * FragsK + bk),
|
||||
c_buf.at(bm * FragsN + bn),
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
#include <stdio.h>
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
#endif
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @enum SparseCompressionIndex
|
||||
* @brief Indicates which set of sparse-indices within a VGPR starting at srcC
|
||||
* containing 8-bits (for 16-bit source data) or 16-bits (for 8-bit source data)
|
||||
* of index information for a lane. \see DefaultSparseMfmaCtrlFlags
|
||||
*/
|
||||
enum struct SparseCompressionIndex : int
|
||||
{
|
||||
FIRST = 0, // Uses bits [7:0] or [15..0], for 16 and 8 bit data respectively
|
||||
SECOND = 1, // Uses bits [15:8] or [31:16], for 16 and 8 bit data respectively
|
||||
THIRD = 2, // Uses bits [23:16]
|
||||
FOURTH = 3, // Uses bits [31:24]
|
||||
};
|
||||
|
||||
// to_string methods for enum classes
|
||||
CK_TILE_HOST_DEVICE constexpr const char* to_string(SparseCompressionIndex compressionIndex)
|
||||
{
|
||||
switch(compressionIndex)
|
||||
{
|
||||
case SparseCompressionIndex::FIRST: return "FIRST";
|
||||
case SparseCompressionIndex::SECOND: return "SECOND";
|
||||
case SparseCompressionIndex::THIRD: return "THIRD";
|
||||
case SparseCompressionIndex::FOURTH: return "FOURTH";
|
||||
}
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
namespace sparse::detail {
|
||||
|
||||
/**
|
||||
* @struct BuiltinParams
|
||||
* @brief Translates the SparseCompressionIndex to the correct CBSZ and ABID pairs for sparse
|
||||
* builtins. The actual behavior of the builtin depends on the input data type: 16-bit source data:
|
||||
* If CBSZ=0, ABID selects one of four 8-bit sets of sparse-indices within a VGPR starting at srcC
|
||||
* containing 8-bits of index information for a lane. If CBSZ!=0 the very first is selected
|
||||
* (VGPR[srcC][7..0]).
|
||||
*
|
||||
* 8-bit source data:
|
||||
* If CBSZ=0, ABID selects one of two 16-bit sets of sparse-indices within a VGPR starting at srcC
|
||||
* containing 16-bits of index information for a lane. If CBSZ!=0; the very first is selected
|
||||
* (VGPR[srcC][15..0]).
|
||||
*/
|
||||
struct BuiltinParams
|
||||
{
|
||||
int UseFirstIndex; // CBSZ
|
||||
int ByteIndexToOverride; // ABID
|
||||
};
|
||||
|
||||
template <SparseCompressionIndex Idx>
|
||||
static constexpr BuiltinParams getBuiltinParams()
|
||||
{
|
||||
// TODO c++20: designated initializers
|
||||
if constexpr(Idx == SparseCompressionIndex::FIRST)
|
||||
{
|
||||
return BuiltinParams{1, 0};
|
||||
}
|
||||
else
|
||||
{
|
||||
return BuiltinParams{0, static_cast<int>(Idx)};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sparse::detail
|
||||
|
||||
/**
|
||||
* @struct DefaultSparseMfmaCtrlFlags
|
||||
* @brief Default MFMA sparse flags, select (VGPR[srcC][7..0]) if srcC is
|
||||
* 16-bit or (VGPR[srcC][15..0]) if srcC is 8-bit.
|
||||
*/
|
||||
struct DefaultSparseMfmaCtrlFlags
|
||||
{
|
||||
static constexpr SparseCompressionIndex CompressionIndex = SparseCompressionIndex::FIRST;
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE void print_flags(DefaultSparseMfmaCtrlFlags const& ctrlFlags)
|
||||
{
|
||||
printf("CtrlFlags CompressionIndex : %s\n", to_string(ctrlFlags.CompressionIndex));
|
||||
}
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
/**
|
||||
* @concept SparseMfmaCtrlFlags
|
||||
* @brief Expresses the interface of required members for each CtrlFlags type
|
||||
*/
|
||||
template <typename CtrlFlags>
|
||||
concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) {
|
||||
// Flag members for sparse MFMA instructions
|
||||
{ CtrlFlags::CompressionIndex } -> std::convertible_to<SparseCompressionIndex>;
|
||||
};
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
@@ -53,7 +53,6 @@ struct MmaDefaultSelector<ADataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
DefaultWmmaCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE>::SelectedOp;
|
||||
};
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
@@ -18,20 +19,20 @@ namespace ck_tile::core::arch::mma {
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
@@ -43,20 +44,20 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
@@ -68,20 +69,20 @@ struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp16_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp16_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
@@ -93,21 +94,21 @@ struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, bf16_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<bf16_t, bf16_t, bf16_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
@@ -119,30 +120,31 @@ struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32(true, // A signedness
|
||||
aVec,
|
||||
true, // B signedness
|
||||
bVec,
|
||||
cVec,
|
||||
idx,
|
||||
CtrlFlags::Clamp)};
|
||||
P::clamp)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -150,21 +152,21 @@ struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 32u, CtrlFlags, CompilerTar
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
@@ -176,21 +178,21 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
@@ -202,21 +204,21 @@ struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
@@ -228,21 +230,21 @@ struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
@@ -250,51 +252,55 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32";
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32(true, // A signedness
|
||||
bit_cast<int32_t>(aVec),
|
||||
true, // B signedness
|
||||
bit_cast<int32x2_t>(bVec),
|
||||
cVec,
|
||||
idx,
|
||||
CtrlFlags::Clamp)};
|
||||
P::clamp)};
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 64u, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 64u, 32u, 32, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32";
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32(true, // A signedness
|
||||
bit_cast<int32x2_t>(aVec),
|
||||
true, // B signedness
|
||||
bit_cast<int32x4_t>(bVec),
|
||||
cVec,
|
||||
idx,
|
||||
CtrlFlags::Clamp)};
|
||||
P::clamp)};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
// TODO: Specifically for gfx11 wmma, we need to deal with quirks such as:
|
||||
@@ -46,20 +47,20 @@ namespace ck_tile::core::arch::mma {
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX11
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
@@ -71,20 +72,20 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX11
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
@@ -96,29 +97,30 @@ struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX11
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, // A signedness
|
||||
bit_cast<int32x4_t>(aVec),
|
||||
true, // B signedness
|
||||
bit_cast<int32x4_t>(bVec),
|
||||
cVec,
|
||||
CtrlFlags::Clamp)};
|
||||
P::clamp)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -126,29 +128,30 @@ struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTar
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for pk_int4_t, pk_int4_t, int32_t MMA operation on GFX11
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
|
||||
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx11_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32(true, // A signedness
|
||||
bit_cast<int32x2_t>(aVec),
|
||||
true, // B signedness
|
||||
bit_cast<int32x2_t>(bVec),
|
||||
cVec,
|
||||
CtrlFlags::Clamp)};
|
||||
P::clamp)};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
@@ -31,21 +32,21 @@ namespace ck_tile::core::arch::mma {
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
@@ -57,21 +58,21 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
@@ -83,21 +84,21 @@ struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp16_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
@@ -109,21 +110,21 @@ struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, bf16_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<bf16_t, bf16_t, bf16_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
@@ -135,30 +136,31 @@ struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // A signedness
|
||||
bit_cast<int32x2_t>(aVec),
|
||||
true, // B signedness
|
||||
bit_cast<int32x2_t>(bVec),
|
||||
cVec,
|
||||
CtrlFlags::Clamp)};
|
||||
P::clamp)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -166,21 +168,21 @@ struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTar
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
@@ -193,21 +195,21 @@ struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
@@ -220,21 +222,21 @@ struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
@@ -247,21 +249,21 @@ struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
@@ -274,30 +276,31 @@ struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for pk_int4_t, pk_int4_t, int32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 16u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12";
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12(true, // A signedness
|
||||
bit_cast<int32_t>(aVec),
|
||||
true, // B signedness
|
||||
bit_cast<int32_t>(bVec),
|
||||
cVec,
|
||||
CtrlFlags::Clamp)};
|
||||
P::clamp)};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -305,30 +308,31 @@ struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 16u, CtrlFlags, Compi
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_wmma for pk_int4_t, pk_int4_t, int32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 template <amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
struct amdgcn_mma<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 32u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<pk_int4_t, pk_int4_t, int32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12";
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec)
|
||||
{
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
return {__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12(true, // A signedness
|
||||
bit_cast<int32x2_t>(aVec),
|
||||
true, // B signedness
|
||||
bit_cast<int32x2_t>(bVec),
|
||||
cVec,
|
||||
CtrlFlags::Clamp)};
|
||||
P::clamp)};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -52,7 +52,6 @@ struct MmaDefaultSelector<ADataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
DefaultWmmaCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>::SelectedOp;
|
||||
};
|
||||
|
||||
@@ -4,8 +4,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
@@ -50,25 +48,4 @@ struct is_mma_op_wmma<MmaOp, std::enable_if_t<std::is_same_v<typename MmaOp::OpT
|
||||
template <typename MmaOp>
|
||||
static constexpr bool is_mma_op_wmma_v = is_mma_op_wmma<MmaOp>::value;
|
||||
|
||||
/**
|
||||
* @struct DefaultWmmaCtrlFlags
|
||||
* @brief Default WMMA control flags for dense and sparse WMMA operations.
|
||||
*/
|
||||
struct DefaultWmmaCtrlFlags
|
||||
{
|
||||
constexpr static bool Clamp = false;
|
||||
|
||||
// Only has an effect on gfx11 when the accumulator is 16-bit.
|
||||
// Determines which half of the 32-bit accum register to use for the 16-bit result.
|
||||
// false = low bits [15:0], true = high bits [31:16]
|
||||
constexpr static bool UseHighAccumBits = true;
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE void print_flags(DefaultWmmaCtrlFlags const& ctrlFlags)
|
||||
{
|
||||
printf("CtrlFlags Clamp / UseHighAccumBits : %d / %d\n",
|
||||
ctrlFlags.Clamp,
|
||||
ctrlFlags.UseHighAccumBits);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
@@ -9,21 +9,6 @@
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct DuplicateTransform
|
||||
* @brief Transform to duplicate low register elements to high register elements
|
||||
*/
|
||||
struct DuplicateTransform
|
||||
{
|
||||
template <typename VecType>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
|
||||
{
|
||||
// TODO: Implement duplication logic to broadcast low
|
||||
// register elements to high elements [0 - (N/2 -1)] -> [N/2 - (N-1)]
|
||||
return std::forward<VecType>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct PadTransform
|
||||
* @brief Transform to pad data from original type to b32 type
|
||||
@@ -59,8 +44,8 @@ struct UnpadTransform
|
||||
*/
|
||||
struct MmaDefaultTransformsGfx11
|
||||
{
|
||||
using ATransform = DuplicateTransform;
|
||||
using BTransform = DuplicateTransform;
|
||||
using ATransform = PassThroughTransform;
|
||||
using BTransform = PassThroughTransform;
|
||||
using CTransform = PadTransform;
|
||||
using DTransform = UnpadTransform;
|
||||
};
|
||||
|
||||
@@ -83,6 +83,21 @@ struct SwapReuse_ : bool_constant<Value>
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t Value>
|
||||
struct Cbsz : number<Value>
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t Value>
|
||||
struct Abid : number<Value>
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t Value>
|
||||
struct Blgp : number<Value>
|
||||
{
|
||||
};
|
||||
|
||||
struct WarpGemmDefaultParams
|
||||
{
|
||||
using clamp = bool_constant<false>;
|
||||
@@ -94,6 +109,9 @@ struct WarpGemmDefaultParams
|
||||
using swap_reuse = bool_constant<false>; // internal use only
|
||||
using scale_a = number<0>;
|
||||
using scale_b = number<0>;
|
||||
using cbsz = number<0>;
|
||||
using abid = number<0>;
|
||||
using blgp = number<0>;
|
||||
};
|
||||
|
||||
template <typename T, template <index_t> class Tag>
|
||||
@@ -151,6 +169,9 @@ class WarpGemmParamsParser
|
||||
public:
|
||||
static constexpr bool clamp = extract<Clamp, WarpGemmDefaultParams::clamp>();
|
||||
static constexpr bool post_nop = extract<PostNop, WarpGemmDefaultParams::post_nop>();
|
||||
static constexpr index_t cbsz = extract<Cbsz, WarpGemmDefaultParams::cbsz>();
|
||||
static constexpr index_t abid = extract<Abid, WarpGemmDefaultParams::abid>();
|
||||
static constexpr index_t blgp = extract<Blgp, WarpGemmDefaultParams::blgp>();
|
||||
static constexpr bool reuse_a = swap_reuse ? raw_reuse_b : raw_reuse_a;
|
||||
static constexpr bool reuse_b = swap_reuse ? raw_reuse_a : raw_reuse_b;
|
||||
static constexpr index_t op_sel_a = swap_reuse ? raw_op_sel_b : raw_op_sel_a;
|
||||
|
||||
@@ -87,6 +87,3 @@ if(GPU_TARGETS MATCHES "gfx120")
|
||||
target_compile_options(test_amdgcn_mma_layout_gfx12 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
_add_mma_gtest(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp)
|
||||
target_compile_options(test_amdgcn_mma_pipeline PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdint>
|
||||
#include <gtest/gtest.h>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
|
||||
|
||||
namespace {
|
||||
using namespace ck_tile::core::arch::mma;
|
||||
}
|
||||
|
||||
TEST(MmaPipelineOptionFlagsTests, ConversionTests)
|
||||
{
|
||||
MmaPipelineOptionFlags flags_0{};
|
||||
MmaPipelineOptionFlags flags_1{MmaPipelineOptionFlag::ABSwap};
|
||||
MmaPipelineOptionFlags flags_2{MmaPipelineOptionFlag::COMPRESS_A};
|
||||
MmaPipelineOptionFlags flags_3{0b11}; // TODO c++20 - remove this
|
||||
|
||||
EXPECT_TRUE(flags_0.testFlag(MmaPipelineOptionFlag::NONE));
|
||||
EXPECT_FALSE(flags_0.testFlag(MmaPipelineOptionFlag::ABSwap));
|
||||
EXPECT_FALSE(flags_0.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
|
||||
|
||||
EXPECT_TRUE(flags_1.testFlag(MmaPipelineOptionFlag::ABSwap));
|
||||
EXPECT_FALSE(flags_1.testFlag(MmaPipelineOptionFlag::NONE));
|
||||
EXPECT_FALSE(flags_1.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
|
||||
|
||||
EXPECT_TRUE(flags_2.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
|
||||
EXPECT_FALSE(flags_2.testFlag(MmaPipelineOptionFlag::NONE));
|
||||
EXPECT_FALSE(flags_2.testFlag(MmaPipelineOptionFlag::ABSwap));
|
||||
|
||||
EXPECT_TRUE(flags_3.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
|
||||
EXPECT_TRUE(flags_3.testFlag(MmaPipelineOptionFlag::ABSwap));
|
||||
EXPECT_FALSE(flags_3.testFlag(MmaPipelineOptionFlag::NONE));
|
||||
}
|
||||
|
||||
TEST(MmaPipelineOptionFlagsTests, OperatorsTests)
|
||||
{
|
||||
MmaPipelineOptionFlags flags{};
|
||||
|
||||
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::NONE));
|
||||
|
||||
flags |= MmaPipelineOptionFlag::ABSwap;
|
||||
|
||||
EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE));
|
||||
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::ABSwap));
|
||||
|
||||
flags |= MmaPipelineOptionFlag::COMPRESS_A;
|
||||
|
||||
EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE));
|
||||
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::ABSwap));
|
||||
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
|
||||
|
||||
flags &= MmaPipelineOptionFlag::COMPRESS_A;
|
||||
|
||||
EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE));
|
||||
EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::ABSwap));
|
||||
EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::COMPRESS_A));
|
||||
|
||||
EXPECT_FALSE((~flags).testFlag(MmaPipelineOptionFlag::NONE));
|
||||
EXPECT_TRUE((~flags).testFlag(MmaPipelineOptionFlag::ABSwap));
|
||||
EXPECT_FALSE((~flags).testFlag(MmaPipelineOptionFlag::COMPRESS_A));
|
||||
}
|
||||
@@ -40,7 +40,6 @@ void ScaleMfmaGfx950Specialization_impl()
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
DefaultScaleMfmaCtrlFlags,
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::SCALE>;
|
||||
|
||||
@@ -79,10 +78,7 @@ TEST(ScaleMMATrait, ScaleMfmaGfx950Specialization)
|
||||
std::cout << "GFX950 scale MFMA specialization is correct" << std::endl;
|
||||
}
|
||||
|
||||
// TODO: It seems like the ExecSignature concept (and hence MmaOpI) can not be made to work for a
|
||||
// templated device function for some reason. Disable test for now and fix this once we are using
|
||||
// the variadic template pack for flags...
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER && 0
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
@@ -97,7 +93,6 @@ void TestConceptRequirements_impl()
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
DefaultScaleMfmaCtrlFlags,
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::SCALE>;
|
||||
|
||||
@@ -107,7 +102,7 @@ void TestConceptRequirements_impl()
|
||||
|
||||
TEST(ScaleMMATrait, TestConceptRequirements)
|
||||
{
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER && 0
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
TestConceptRequirements_impl<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u>();
|
||||
TestConceptRequirements_impl<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u>();
|
||||
TestConceptRequirements_impl<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u>();
|
||||
@@ -216,9 +211,7 @@ struct ScalePipelineKernel
|
||||
constexpr int32_t replicate_byte = 0x01010101;
|
||||
ScaleAType scale_a = 126u * replicate_byte;
|
||||
ScaleBType scale_b = 129u * replicate_byte;
|
||||
static constexpr index_t opselA = 0;
|
||||
static constexpr index_t opselB = 0;
|
||||
Pipeline::template exec<opselA, opselB>(a, b, c, scale_a, scale_b);
|
||||
Pipeline::template exec<OpSelA<0>, OpSelB<0>>(a, b, c, scale_a, scale_b);
|
||||
__builtin_memcpy(
|
||||
static_cast<uint8_t*>(c_per_lane) + lane * sizeof(CTensor), &c, sizeof(CTensor));
|
||||
}
|
||||
@@ -399,9 +392,7 @@ TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_32x32x64_Real)
|
||||
// constexpr int32_t replicate_byte = 0x01010101;
|
||||
// ScaleAType scale_a = 126u * replicate_byte;
|
||||
// ScaleBType scale_b = 129u * replicate_byte;
|
||||
// static constexpr index_t opselA = 0;
|
||||
// static constexpr index_t opselB = 0;
|
||||
// Pipeline::template exec<opselA, opselB>(a, b, c, scale_a, scale_b);
|
||||
// Pipeline::template exec<OpSelA<0>, OpSelB<0>>(a, b, c, scale_a, scale_b);
|
||||
// __builtin_memcpy(
|
||||
// static_cast<uint8_t*>(c_per_lane) + lane * sizeof(CTensor), &c, sizeof(CTensor));
|
||||
// }
|
||||
|
||||
@@ -39,7 +39,6 @@ TEST(SparseMMATrait, SparseMfmaGfx950Specialization)
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
DefaultSparseMfmaCtrlFlags,
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::SPARSE>;
|
||||
|
||||
@@ -60,7 +59,6 @@ TEST(SparseMMATrait, MmaOpTraitsIntegration)
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
DefaultSparseMfmaCtrlFlags,
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::SPARSE>;
|
||||
|
||||
@@ -83,7 +81,6 @@ TEST(SparseMMATrait, TestConceptRequirements)
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
DefaultSparseMfmaCtrlFlags,
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::SPARSE>;
|
||||
EXPECT_TRUE(MmaOpI<TestSparseMmma>);
|
||||
@@ -95,15 +92,8 @@ TEST(SparseMMATrait, TestConceptRequirements)
|
||||
TEST(SparseMMATrait, DenseVsSparseDistinction)
|
||||
{
|
||||
// Dense MFMA from mfma/mfma_gfx9.hpp
|
||||
using DenseMfma = amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
DefaultMfmaCtrlFlags,
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::DENSE>;
|
||||
using DenseMfma =
|
||||
amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CompilerTargetGfx950, MmaOpFamily::DENSE>;
|
||||
|
||||
// Sparse MFMA on GFX950
|
||||
using SparseMfma = amdgcn_mma<fp16_t,
|
||||
@@ -112,7 +102,6 @@ TEST(SparseMMATrait, DenseVsSparseDistinction)
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
DefaultSparseMfmaCtrlFlags,
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::SPARSE>;
|
||||
|
||||
|
||||
@@ -27,9 +27,6 @@ using namespace ck_tile::core::arch::testing;
|
||||
constexpr uint32_t DummyTargetIdVal = 55555u;
|
||||
using DummyCompilerTarget = amdgcn_target<static_cast<amdgcn_target_id>(DummyTargetIdVal)>;
|
||||
struct DummyOpType;
|
||||
struct DummyCtrlFlags
|
||||
{
|
||||
};
|
||||
|
||||
/** @brief Returns true if the given target id matches the dummy */
|
||||
constexpr bool is_dummy_target(DummyCompilerTarget dummy)
|
||||
@@ -49,7 +46,7 @@ using enable_if_target_id_dummy_t = std::enable_if_t<is_dummy_target(CompilerTar
|
||||
template <typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, DummyCtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_id_dummy_t<CompilerTarget>>
|
||||
struct amdgcn_mma<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_id_dummy_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, 64u, 1, 1, 1, 1, 1, 1, 1, DummyOpType, MmaOpFamily::DENSE>
|
||||
// clang-format on
|
||||
{
|
||||
@@ -63,15 +60,8 @@ struct amdgcn_mma<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, DummyCtrlFlags, CompilerTa
|
||||
// Have an alias so we can test supported arch vs unsupported arch
|
||||
// TODO: c++20 template <amdgcn_target_arch_id CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
using DummyAmdgcnMma = amdgcn_mma<fp32_t,
|
||||
fp32_t,
|
||||
fp32_t,
|
||||
8u,
|
||||
8u,
|
||||
8u,
|
||||
DummyCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
using DummyAmdgcnMma =
|
||||
amdgcn_mma<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, CompilerTarget, MmaOpFamily::DENSE>;
|
||||
|
||||
/*! @struct MmaDefaultSelector
|
||||
* @brief For dummy Id only, instantiate tests for both MFMA and WMMA selectors so we can them both
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
@@ -256,12 +257,10 @@ struct MmaLayoutTestKernel
|
||||
{
|
||||
// The actual scale is computed as pow(2, scale - 127), so:
|
||||
// 125 -> 2^-2 and 129 -> 2^2.
|
||||
int scale_A = 125;
|
||||
int scale_B = 129;
|
||||
static constexpr index_t opselA = 0;
|
||||
static constexpr index_t opselB = 0;
|
||||
c_frag =
|
||||
MmaOp::template exec<opselA, opselB>(a_frag, b_frag, c_frag, scale_A, scale_B);
|
||||
int scale_A = 125;
|
||||
int scale_B = 129;
|
||||
c_frag = MmaOp::template exec<OpSelA<0>, OpSelB<0>>(
|
||||
a_frag, b_frag, c_frag, scale_A, scale_B);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -357,145 +356,145 @@ void run_mma_layout_test()
|
||||
|
||||
// available on all gfx9 (gfx908, gfx90a, gfx942, gfx950)
|
||||
using Gfx9CommonIntrinsics = ::testing::Types<
|
||||
amdgcn_mma<F32, F32, F32, 32u, 64u, 1u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x1f32
|
||||
amdgcn_mma<F32, F32, F32, 64u, 32u, 1u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x1f32
|
||||
amdgcn_mma<F32, F32, F32, 16u, 64u, 1u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x1f32
|
||||
amdgcn_mma<F32, F32, F32, 64u, 16u, 1u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x1f32
|
||||
amdgcn_mma<F32, F32, F32, 4u, 64u, 1u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x1f32
|
||||
amdgcn_mma<F32, F32, F32, 64u, 4u, 1u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x1f32
|
||||
amdgcn_mma<F32, F32, F32, 32u, 32u, 2u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x2f32
|
||||
amdgcn_mma<F32, F32, F32, 16u, 16u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4f32
|
||||
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 16u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4f16
|
||||
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 32u, 8u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x8f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
|
||||
amdgcn_mma<I8, I8, I32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x4i8
|
||||
amdgcn_mma<I8, I8, I32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x4i8
|
||||
amdgcn_mma<I8, I8, I32, 16u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x4i8
|
||||
amdgcn_mma<I8, I8, I32, 64u, 16u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x4i8
|
||||
amdgcn_mma<I8, I8, I32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_4x4x4i8
|
||||
amdgcn_mma<I8, I8, I32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE> // mfma_i32_4x4x4i8
|
||||
amdgcn_mma<F32, F32, F32, 32u, 64u, 1u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x1f32
|
||||
amdgcn_mma<F32, F32, F32, 64u, 32u, 1u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x1f32
|
||||
amdgcn_mma<F32, F32, F32, 16u, 64u, 1u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x1f32
|
||||
amdgcn_mma<F32, F32, F32, 64u, 16u, 1u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x1f32
|
||||
amdgcn_mma<F32, F32, F32, 4u, 64u, 1u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x1f32
|
||||
amdgcn_mma<F32, F32, F32, 64u, 4u, 1u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x1f32
|
||||
amdgcn_mma<F32, F32, F32, 32u, 32u, 2u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x2f32
|
||||
amdgcn_mma<F32, F32, F32, 16u, 16u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4f32
|
||||
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 16u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4f16
|
||||
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 32u, 8u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x8f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
|
||||
amdgcn_mma<I8, I8, I32, 32u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x4i8
|
||||
amdgcn_mma<I8, I8, I32, 64u, 32u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x4i8
|
||||
amdgcn_mma<I8, I8, I32, 16u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x4i8
|
||||
amdgcn_mma<I8, I8, I32, 64u, 16u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x4i8
|
||||
amdgcn_mma<I8, I8, I32, 4u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_4x4x4i8
|
||||
amdgcn_mma<I8, I8, I32, 64u, 4u, 4u, TestTarget, MmaOpFamily::DENSE> // mfma_i32_4x4x4i8
|
||||
>;
|
||||
using Gfx908andGfx90aIntrinsics = ::testing::Types<
|
||||
amdgcn_mma<BF16, BF16, F32, 32u, 64u, 2u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x2bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 64u, 32u, 2u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x2bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 64u, 2u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x2bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 64u, 16u, 2u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x2bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 4u, 64u, 2u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x2bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 64u, 4u, 2u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x2bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 8u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x8bf16
|
||||
amdgcn_mma<I8, I8, I32, 32u, 32u, 8u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x8i8
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE> // mfma_i32_16x16x16i8
|
||||
amdgcn_mma<BF16, BF16, F32, 32u, 64u, 2u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x2bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 64u, 32u, 2u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x2bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 64u, 2u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x2bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 64u, 16u, 2u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x2bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 4u, 64u, 2u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x2bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 64u, 4u, 2u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x2bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 8u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x8bf16
|
||||
amdgcn_mma<I8, I8, I32, 32u, 32u, 8u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x8i8
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, TestTarget, MmaOpFamily::DENSE> // mfma_i32_16x16x16i8
|
||||
>;
|
||||
using Gfx90aAndHigherIntrinsics = ::testing::Types<
|
||||
amdgcn_mma<BF16, BF16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4bf16_1k
|
||||
amdgcn_mma<BF16, BF16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4bf16_1k
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4bf16_1k
|
||||
amdgcn_mma<BF16, BF16, F32, 64u, 16u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4bf16_1k
|
||||
amdgcn_mma<BF16, BF16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4bf16_1k
|
||||
amdgcn_mma<BF16, BF16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4bf16_1k
|
||||
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 8u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x8bf16_1k
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x16bf16_1k
|
||||
amdgcn_mma<F64, F64, F64, 16u, 16u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f64_16x16x4f64
|
||||
amdgcn_mma<F64, F64, F64, 4u, 16u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f64_4x4x4f64
|
||||
amdgcn_mma<F64, F64, F64, 16u, 4u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE> // mfma_f64_4x4x4f64
|
||||
amdgcn_mma<BF16, BF16, F32, 32u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4bf16_1k
|
||||
amdgcn_mma<BF16, BF16, F32, 64u, 32u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x4bf16_1k
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4bf16_1k
|
||||
amdgcn_mma<BF16, BF16, F32, 64u, 16u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x4bf16_1k
|
||||
amdgcn_mma<BF16, BF16, F32, 4u, 64u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4bf16_1k
|
||||
amdgcn_mma<BF16, BF16, F32, 64u, 4u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_4x4x4bf16_1k
|
||||
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 8u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x8bf16_1k
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x16bf16_1k
|
||||
amdgcn_mma<F64, F64, F64, 16u, 16u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f64_16x16x4f64
|
||||
amdgcn_mma<F64, F64, F64, 4u, 16u, 4u, TestTarget, MmaOpFamily::DENSE>, // mfma_f64_4x4x4f64
|
||||
amdgcn_mma<F64, F64, F64, 16u, 4u, 4u, TestTarget, MmaOpFamily::DENSE> // mfma_f64_4x4x4f64
|
||||
>;
|
||||
using Gfx942AndHigherIntrinsics = ::testing::Types<
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x32_i8
|
||||
amdgcn_mma<I8, I8, I32, 32u, 32u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x16_i8
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_bf8_bf8
|
||||
amdgcn_mma<BF8, F8, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_bf8_fp8
|
||||
amdgcn_mma<F8, BF8, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_fp8_bf8
|
||||
amdgcn_mma<F8, F8, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_fp8_fp8
|
||||
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_bf8_bf8
|
||||
amdgcn_mma<BF8, F8, F32, 32u, 32u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_bf8_fp8
|
||||
amdgcn_mma<F8, BF8, F32, 32u, 32u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_fp8_bf8
|
||||
amdgcn_mma<F8, F8, F32, 32u, 32u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_fp8_fp8
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x32_f16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 32u, 16u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x16_f16
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x32_bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 16u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x16_bf16
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_16x16x64_i8
|
||||
amdgcn_mma<I8, I8, I32, 32u, 32u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_32x32x32_i8
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_bf8_bf8
|
||||
amdgcn_mma<BF8, F8, F32, 16u, 16u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_bf8_fp8
|
||||
amdgcn_mma<F8, BF8, F32, 16u, 16u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_fp8_bf8
|
||||
amdgcn_mma<F8, F8, F32, 16u, 16u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_fp8_fp8
|
||||
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_bf8_bf8
|
||||
amdgcn_mma<BF8, F8, F32, 32u, 32u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_bf8_fp8
|
||||
amdgcn_mma<F8, BF8, F32, 32u, 32u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_fp8_bf8
|
||||
amdgcn_mma<F8, F8, F32, 32u, 32u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE> // smfmac_f32_32x32x32_fp8_fp8
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x32_i8
|
||||
amdgcn_mma<I8, I8, I32, 32u, 32u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x16_i8
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_bf8_bf8
|
||||
amdgcn_mma<BF8, F8, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_bf8_fp8
|
||||
amdgcn_mma<F8, BF8, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_fp8_bf8
|
||||
amdgcn_mma<F8, F8, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_fp8_fp8
|
||||
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_bf8_bf8
|
||||
amdgcn_mma<BF8, F8, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_bf8_fp8
|
||||
amdgcn_mma<F8, BF8, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_fp8_bf8
|
||||
amdgcn_mma<F8, F8, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_fp8_fp8
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x32_f16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x16_f16
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x32_bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x16_bf16
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_16x16x64_i8
|
||||
amdgcn_mma<I8, I8, I32, 32u, 32u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_32x32x32_i8
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_bf8_bf8
|
||||
amdgcn_mma<BF8, F8, F32, 16u, 16u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_bf8_fp8
|
||||
amdgcn_mma<F8, BF8, F32, 16u, 16u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_fp8_bf8
|
||||
amdgcn_mma<F8, F8, F32, 16u, 16u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_fp8_fp8
|
||||
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_bf8_bf8
|
||||
amdgcn_mma<BF8, F8, F32, 32u, 32u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_bf8_fp8
|
||||
amdgcn_mma<F8, BF8, F32, 32u, 32u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_fp8_bf8
|
||||
amdgcn_mma<F8, F8, F32, 32u, 32u, 32u, TestTarget, MmaOpFamily::SPARSE> // smfmac_f32_32x32x32_fp8_fp8
|
||||
>;
|
||||
using Gfx942Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<TF32, TF32, F32, 16u, 16u, 8u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x8_xf32
|
||||
amdgcn_mma<TF32, TF32, F32, 32u, 32u, 4u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE> // mfma_f32_32x32x4_xf32
|
||||
amdgcn_mma<TF32, TF32, F32, 16u, 16u, 8u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x8_xf32
|
||||
amdgcn_mma<TF32, TF32, F32, 32u, 32u, 4u, TestTarget, MmaOpFamily::DENSE> // mfma_f32_32x32x4_xf32
|
||||
>;
|
||||
using Gfx950Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_bf16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 32u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_f16
|
||||
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 16u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_bf16
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 64u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x64_i8
|
||||
amdgcn_mma<I8, I8, I32, 32u, 32u, 32u, DefaultMfmaCtrlFlags, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x32_i8
|
||||
amdgcn_mma<F8, F8, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
|
||||
amdgcn_mma<F4, F4, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
|
||||
amdgcn_mma<F6, F6, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
|
||||
amdgcn_mma<BF6, BF6, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
|
||||
amdgcn_mma<F8, F8, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
|
||||
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
|
||||
amdgcn_mma<F4, F4, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
|
||||
amdgcn_mma<F6, F6, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
|
||||
amdgcn_mma<BF6, BF6, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_f16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 32u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_f16
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 32u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_bf16
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 128u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_16x16x128_i8
|
||||
amdgcn_mma<I8, I8, I32, 32u, 32u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_32x32x64_i8
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 128u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_bf8_bf8
|
||||
amdgcn_mma<BF8, F8, F32, 16u, 16u, 128u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_bf8_fp8
|
||||
amdgcn_mma<F8, BF8, F32, 16u, 16u, 128u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_fp8_bf8
|
||||
amdgcn_mma<F8, F8, F32, 16u, 16u, 128u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_fp8_fp8
|
||||
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x64_bf8_bf8
|
||||
amdgcn_mma<BF8, F8, F32, 32u, 32u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x64_bf8_fp8
|
||||
amdgcn_mma<F8, BF8, F32, 32u, 32u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x64_fp8_bf8
|
||||
amdgcn_mma<F8, F8, F32, 32u, 32u, 64u, DefaultSparseMfmaCtrlFlags, TestTarget, MmaOpFamily::SPARSE> // smfmac_f32_32x32x64_fp8_fp8
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_bf16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_f16
|
||||
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 16u, TestTarget, MmaOpFamily::DENSE>, // mfma_f32_32x32x16_bf16
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 64u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_16x16x64_i8
|
||||
amdgcn_mma<I8, I8, I32, 32u, 32u, 32u, TestTarget, MmaOpFamily::DENSE>, // mfma_i32_32x32x32_i8
|
||||
amdgcn_mma<F8, F8, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
|
||||
amdgcn_mma<F4, F4, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
|
||||
amdgcn_mma<F6, F6, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
|
||||
amdgcn_mma<BF6, BF6, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
|
||||
amdgcn_mma<F8, F8, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
|
||||
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
|
||||
amdgcn_mma<F4, F4, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
|
||||
amdgcn_mma<F6, F6, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
|
||||
amdgcn_mma<BF6, BF6, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_f16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 32u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_f16
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x64_bf16
|
||||
amdgcn_mma<BF16, BF16, F32, 32u, 32u, 32u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x32_bf16
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_16x16x128_i8
|
||||
amdgcn_mma<I8, I8, I32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_i32_32x32x64_i8
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_bf8_bf8
|
||||
amdgcn_mma<BF8, F8, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_bf8_fp8
|
||||
amdgcn_mma<F8, BF8, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_fp8_bf8
|
||||
amdgcn_mma<F8, F8, F32, 16u, 16u, 128u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_16x16x128_fp8_fp8
|
||||
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x64_bf8_bf8
|
||||
amdgcn_mma<BF8, F8, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x64_bf8_fp8
|
||||
amdgcn_mma<F8, BF8, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SPARSE>, // smfmac_f32_32x32x64_fp8_bf8
|
||||
amdgcn_mma<F8, F8, F32, 32u, 32u, 64u, TestTarget, MmaOpFamily::SPARSE> // smfmac_f32_32x32x64_fp8_fp8
|
||||
>;
|
||||
using Gfx11Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf16_w32
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu8_w32
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target11, MmaOpFamily::DENSE> // wmma_i32_16x16x16_iu4_w32
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, Target11, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf16_w32
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, Target11, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu8_w32
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, Target11, MmaOpFamily::DENSE> // wmma_i32_16x16x16_iu4_w32
|
||||
>;
|
||||
using Gfx12Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf16_w32_gfx12
|
||||
amdgcn_mma<F16, F16, F16, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f16_16x16x16_f16_w32_gfx12
|
||||
amdgcn_mma<BF16, BF16, BF16, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_bf16_16x16x16_bf16_w32_gfx12
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu8_w32_gfx12
|
||||
amdgcn_mma<F8, F8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_fp8_fp8_w32_gfx12
|
||||
amdgcn_mma<F8, BF8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_fp8_bf8_w32_gfx12
|
||||
amdgcn_mma<BF8, F8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf8_fp8_w32_gfx12
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf8_bf8_w32_gfx12
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu4_w32_gfx12
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x32_iu4_w32_gfx12
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_f16_w32
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf16_w32
|
||||
amdgcn_mma<F16, F16, F16, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f16_16x16x32_f16_w32
|
||||
amdgcn_mma<BF16, BF16, BF16, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_bf16_16x16x32_bf16_w32
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_i32_16x16x32_iu8_w32
|
||||
amdgcn_mma<F8, F8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_fp8_fp8_w32
|
||||
amdgcn_mma<F8, BF8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_fp8_bf8_w32
|
||||
amdgcn_mma<BF8, F8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf8_fp8_w32
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf8_bf8_w32
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_i32_16x16x32_iu4_w32
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 64u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE> // swmmac_i32_16x16x64_iu4_w32
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf16_w32_gfx12
|
||||
amdgcn_mma<F16, F16, F16, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f16_16x16x16_f16_w32_gfx12
|
||||
amdgcn_mma<BF16, BF16, BF16, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_bf16_16x16x16_bf16_w32_gfx12
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu8_w32_gfx12
|
||||
amdgcn_mma<F8, F8, F32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_fp8_fp8_w32_gfx12
|
||||
amdgcn_mma<F8, BF8, F32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_fp8_bf8_w32_gfx12
|
||||
amdgcn_mma<BF8, F8, F32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf8_fp8_w32_gfx12
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_bf8_bf8_w32_gfx12
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu4_w32_gfx12
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x32_iu4_w32_gfx12
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_f16_w32
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf16_w32
|
||||
amdgcn_mma<F16, F16, F16, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_f16_16x16x32_f16_w32
|
||||
amdgcn_mma<BF16, BF16, BF16, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_bf16_16x16x32_bf16_w32
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_i32_16x16x32_iu8_w32
|
||||
amdgcn_mma<F8, F8, F32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_fp8_fp8_w32
|
||||
amdgcn_mma<F8, BF8, F32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_fp8_bf8_w32
|
||||
amdgcn_mma<BF8, F8, F32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf8_fp8_w32
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf8_bf8_w32
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, Target12, MmaOpFamily::SPARSE>, // swmmac_i32_16x16x32_iu4_w32
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 64u, Target12, MmaOpFamily::SPARSE> // swmmac_i32_16x16x64_iu4_w32
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user