mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
[CK TILE] Unification of Scale MFMA/WMMA Policy Structs (#5857)
## Motivation The existing unification work supports DENSE and SPARSE intrinsics. In this PR, we enable support for SCALE intrinsics and add example SCALE implementations. ## Technical Details Adding MFMA SCALE intrinsics support, adding tests for MFMA SCALE intrinsics, and adding WMMA SCALE policy trait. Note: fp6 SCALE intrinsics support is not included in this PR, as its handling in ck_tile is currently more specialized and does not follow the same pattern as other datatypes. ## Test Plan Added new tests for the relevant SCALE specialisations. ## Test Result Test should pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
@@ -25,6 +25,13 @@
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/mfma/selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
|
||||
|
||||
@@ -245,7 +245,7 @@ concept MmaOpI = requires(MmaOp op) {
|
||||
{ MmaOp::kCMPerLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCMNumAccess } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCompressionRatio } -> std::convertible_to<unsigned int>;
|
||||
} && (HasExecSignature<MmaOp> || HasExecSignature<MmaOp, int>);
|
||||
} && (HasExecSignature<MmaOp> || HasExecSignature<MmaOp, int> || HasExecSignature<MmaOp, int, int>);
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
@@ -303,6 +303,8 @@ struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
// Include the implementations
|
||||
#include "wmma/wmma.hpp"
|
||||
#include "wmma/wmma.hpp" // should be included before the below headers
|
||||
|
||||
#include "mfma/mfma.hpp"
|
||||
#include "scale/scale.hpp"
|
||||
#include "sparse/sparse.hpp"
|
||||
|
||||
@@ -207,11 +207,17 @@ struct MmaPipelineBase
|
||||
|
||||
/**
|
||||
* @brief Apply the per-operand pre-transforms and buffer formatting to A, B, and C.
|
||||
* @return A @c std::tuple of the transformed (A, B, C) vectors ready for the mma loop.
|
||||
* @return A @c std::tuple of the transformed (A, B, C, [scaleA, scaleB]) vectors ready for the
|
||||
* mma loop.
|
||||
*/
|
||||
template <typename ATransformInputs, typename BTransformInputs, typename CTransformInputs>
|
||||
CK_TILE_DEVICE static decltype(auto)
|
||||
applyTransformsToInputs(ATransformInputs&& a, BTransformInputs&& b, CTransformInputs&& accum)
|
||||
template <typename ATransformInputs,
|
||||
typename BTransformInputs,
|
||||
typename CTransformInputs,
|
||||
typename... ExtraArgs>
|
||||
CK_TILE_DEVICE static decltype(auto) applyTransformsToInputs(ATransformInputs&& a,
|
||||
BTransformInputs&& b,
|
||||
CTransformInputs&& accum,
|
||||
ExtraArgs&&... extras)
|
||||
{
|
||||
using InternalAVecT = typename Derived::InternalAVecT;
|
||||
using InternalBVecT = typename Derived::InternalBVecT;
|
||||
@@ -224,19 +230,18 @@ struct MmaPipelineBase
|
||||
return std::make_tuple(
|
||||
preApplyTransform<InternalAVecT, ATransform>(std::forward<ATransformInputs>(a)),
|
||||
preApplyTransform<InternalBVecT, BTransform>(std::forward<BTransformInputs>(b)),
|
||||
preApplyTransform<InternalCVecT, CTransform>(std::forward<CTransformInputs>(accum)));
|
||||
preApplyTransform<InternalCVecT, CTransform>(std::forward<CTransformInputs>(accum)),
|
||||
std::forward<ExtraArgs>(extras)...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Apply the post-transform and buffer formatting to the C (accumulator) output.
|
||||
* @param vecs The (A, B, C) tuple after @c execImpl; only C is consumed.
|
||||
* @param c_result The accumulator to post-process.
|
||||
* @return The final D output in the user-facing vector type.
|
||||
*/
|
||||
template <typename ATransformResult, typename BTransformResult, typename CTransformResult>
|
||||
CK_TILE_DEVICE static auto
|
||||
applyTransformToOutput(std::tuple<ATransformResult, BTransformResult, CTransformResult>&& vecs)
|
||||
template <typename CTransformResult>
|
||||
CK_TILE_DEVICE static auto applyTransformToOutput(CTransformResult&& c_result)
|
||||
{
|
||||
auto&& [a_result, b_result, c_result] = vecs;
|
||||
static_assert(!is_std_tuple_v<decltype(c_result)>,
|
||||
"If CTransform returns more than the vector, update this function.");
|
||||
|
||||
@@ -270,7 +275,46 @@ struct MmaPipelineBase
|
||||
|
||||
Derived::execImpl(transformed_inputs);
|
||||
|
||||
return applyTransformToOutput(std::move(transformed_inputs));
|
||||
auto&& [a_result, b_result, c_result] = std::move(transformed_inputs);
|
||||
return applyTransformToOutput(std::move(c_result));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return the unsupported exec. This should print a runtime warning. (amdgcn_mma.hpp)
|
||||
// Code should not reach here, but HOST/DEVICE compile passes are
|
||||
// weirdly intertwined and instead of having constexpr in the calling
|
||||
// site (tests) we do this. See also changes by this commit.
|
||||
return Derived::MmaOp::exec({}, {}, {});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename VecTA,
|
||||
typename VecTB,
|
||||
typename VecTC,
|
||||
typename ScaleADataType,
|
||||
typename ScaleBDataType>
|
||||
CK_TILE_DEVICE static decltype(auto)
|
||||
exec(VecTA&& a, VecTB&& b, VecTC&& accum, ScaleADataType&& scale_A, ScaleBDataType&& scale_B)
|
||||
{
|
||||
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
|
||||
{
|
||||
// TODO: c++20: Call template functions with MmaPipelineOptionFlags directly
|
||||
auto transformed_inputs = applyTransformsToInputs(
|
||||
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTB>(b)
|
||||
: std::forward<VecTA>(a),
|
||||
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTA>(a)
|
||||
: std::forward<VecTB>(b),
|
||||
std::forward<VecTC>(accum),
|
||||
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<ScaleBDataType>(scale_B)
|
||||
: std::forward<ScaleADataType>(scale_A),
|
||||
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<ScaleADataType>(scale_A)
|
||||
: std::forward<ScaleBDataType>(scale_B));
|
||||
|
||||
Derived::execImpl(transformed_inputs);
|
||||
|
||||
auto&& [a_result, b_result, c_result, scale_A_result, scale_B_result] =
|
||||
std::move(transformed_inputs);
|
||||
return applyTransformToOutput(std::move(c_result));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
229
include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp
Normal file
229
include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp
Normal file
@@ -0,0 +1,229 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the Scale MFMA instruction for fp8_t A and B
|
||||
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 2, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
|
||||
static_cast<int>(CtrlFlags::OPSEL_A),
|
||||
scale_A,
|
||||
static_cast<int>(CtrlFlags::OPSEL_B),
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the Scale MFMA instruction for bf8_t A and B
|
||||
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 2, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
|
||||
static_cast<int>(CtrlFlags::OPSEL_A),
|
||||
scale_A,
|
||||
static_cast<int>(CtrlFlags::OPSEL_B),
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the Scale MFMA instruction for pk_fp4_t A and B
|
||||
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
|
||||
static_cast<int>(CtrlFlags::OPSEL_A),
|
||||
scale_A,
|
||||
static_cast<int>(CtrlFlags::OPSEL_B),
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the Scale MFMA instruction for fp8_t A and B
|
||||
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 2, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
|
||||
static_cast<int>(CtrlFlags::OPSEL_A),
|
||||
scale_A,
|
||||
static_cast<int>(CtrlFlags::OPSEL_B),
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the Scale MFMA instruction for bf8_t A and B
|
||||
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 2, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
|
||||
static_cast<int>(CtrlFlags::OPSEL_A),
|
||||
scale_A,
|
||||
static_cast<int>(CtrlFlags::OPSEL_B),
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the Scale MFMA instruction for pk_fp4_t A and B
|
||||
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Scale MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
: amdgcn_mma_base<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
|
||||
// clang-format on
|
||||
{
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
bit_cast<int32x8_t>(aVec),
|
||||
bit_cast<int32x8_t>(bVec),
|
||||
cVec,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
|
||||
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
|
||||
static_cast<int>(CtrlFlags::OPSEL_A),
|
||||
scale_A,
|
||||
static_cast<int>(CtrlFlags::OPSEL_B),
|
||||
scale_B)};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
149
include/ck_tile/core/arch/mma/scale/mfma/selector.hpp
Normal file
149
include/ck_tile/core/arch/mma/scale/mfma/selector.hpp
Normal file
@@ -0,0 +1,149 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class ScaleMfmaDefaultSelector
|
||||
* @brief Implements a default scale MFMA selector strategy. The SelectedOp can be unsupported.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM Size of the M dimension
|
||||
* @tparam WaveTileN Size of the N dimension
|
||||
* @tparam WaveTileKTest Size of the K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
std::uint32_t WaveTileM,
|
||||
std::uint32_t WaveTileN,
|
||||
std::uint32_t WaveTileKTest,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) &&
|
||||
// is_power_of_two_integer(WaveTileKTest))
|
||||
struct ScaleMfmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
// Define our candidate MFMA implementation for the current parameters
|
||||
using CandidateOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest,
|
||||
DefaultScaleMfmaCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SCALE>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
|
||||
CandidateOp,
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileKTest,
|
||||
void,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the CDNA default MMA selector strategy for scale MFMA.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
|
||||
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
|
||||
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
std::uint32_t WaveTileM,
|
||||
std::uint32_t WaveTileN,
|
||||
std::uint32_t WaveTileK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
CompilerTarget,
|
||||
OpFamily,
|
||||
enable_if_all<std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID,
|
||||
amdgcn_target_id::GFX950)>,
|
||||
std::enable_if_t<OpFamily == MmaOpFamily::SCALE>>>
|
||||
{
|
||||
private:
|
||||
// Provide the default depth-K search strategy for each class of common MFMA shapes.
|
||||
// Start searching from the largest K dimension MFMA shape down to the smallest.
|
||||
using CandidateOp16x16 = typename ScaleMfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
16u,
|
||||
16u,
|
||||
128u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
using CandidateOp32x32 = typename ScaleMfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
32u,
|
||||
32u,
|
||||
64u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Default operation triggers pass-through
|
||||
using DefaultOp = typename ScaleMfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
1u,
|
||||
1u,
|
||||
1u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the MFMA shape
|
||||
static constexpr bool IsSupported16x16 =
|
||||
MmaOpTraits<CandidateOp16x16>::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) &&
|
||||
(WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u);
|
||||
static constexpr bool IsSupported32x32 =
|
||||
MmaOpTraits<CandidateOp32x32>::IsSupported && (WaveTileM % CandidateOp32x32::kM == 0u) &&
|
||||
(WaveTileN % CandidateOp32x32::kN == 0u) && (WaveTileK % CandidateOp32x32::kK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported MFMA operation for the given fragment shape
|
||||
using SelectedOp =
|
||||
std::conditional_t<IsSupported32x32,
|
||||
CandidateOp32x32,
|
||||
std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
10
include/ck_tile/core/arch/mma/scale/scale.hpp
Normal file
10
include/ck_tile/core/arch/mma/scale/scale.hpp
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
// Include scale MFMA traits and architecture-specific implementations
|
||||
#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"
|
||||
77
include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp
Normal file
77
include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp
Normal file
@@ -0,0 +1,77 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
std::uint32_t FragM,
|
||||
std::uint32_t FragN,
|
||||
std::uint32_t FragK,
|
||||
typename CompilerTarget =
|
||||
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
|
||||
// get_compiler_target(),
|
||||
typename MmaOp_ =
|
||||
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp_ = typename
|
||||
// MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SCALE>::SelectedOp,
|
||||
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
|
||||
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
|
||||
// clang-format off
|
||||
struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>
|
||||
{
|
||||
using Base = MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>;
|
||||
// clang-format on
|
||||
|
||||
using MmaOp = MmaOp_; // Expose the selected MmaOp
|
||||
|
||||
// Expose caller-side vector types
|
||||
using AVecType = typename MmaOp::AVecType;
|
||||
using BVecType = typename MmaOp::BVecType;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
|
||||
// Expose internal vector types
|
||||
using InternalAVecT = typename MmaOp::AVecType;
|
||||
using InternalBVecT = typename MmaOp::BVecType;
|
||||
using InternalCVecT = typename MmaOp::CVecType;
|
||||
|
||||
// Transforms
|
||||
using ATransform = typename MmaTransforms::ATransform;
|
||||
using BTransform = typename MmaTransforms::BTransform;
|
||||
using CTransform = typename MmaTransforms::CTransform;
|
||||
using DTransform = typename MmaTransforms::DTransform;
|
||||
|
||||
template <typename VecTA,
|
||||
typename VecTB,
|
||||
typename VecTC,
|
||||
typename ScaleADataType,
|
||||
typename ScaleBDataType>
|
||||
CK_TILE_DEVICE static void
|
||||
execImpl(std::tuple<VecTA, VecTB, VecTC, ScaleADataType, ScaleBDataType>& vecs)
|
||||
{
|
||||
auto& [a_vec, b_vec, c_vec, scale_A, scale_B] = vecs;
|
||||
c_vec = MmaOp::exec(a_vec, b_vec, c_vec, scale_A, scale_B);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
6
include/ck_tile/core/arch/mma/scale/scale_selector.hpp
Normal file
6
include/ck_tile/core/arch/mma/scale/scale_selector.hpp
Normal file
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/mma/scale/mfma/selector.hpp"
|
||||
93
include/ck_tile/core/arch/mma/scale/scale_traits.hpp
Normal file
93
include/ck_tile/core/arch/mma/scale/scale_traits.hpp
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
// #include "ck_tile/core/numeric/pk_fp6.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
namespace scale::detail {
|
||||
|
||||
template <typename T>
|
||||
struct ScaleDataTypeToFlag;
|
||||
|
||||
template <>
|
||||
struct ScaleDataTypeToFlag<fp8_t> // e4m3
|
||||
{
|
||||
static constexpr std::int32_t value = 0;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ScaleDataTypeToFlag<bf8_t> // e5m2
|
||||
{
|
||||
static constexpr std::int32_t value = 1;
|
||||
};
|
||||
|
||||
// template <>
|
||||
// struct ScaleDataTypeToFlag<pk_fp6_t<1>> // e2m3
|
||||
// {
|
||||
// static constexpr std::int32_t value = 2;
|
||||
// };
|
||||
|
||||
// template <>
|
||||
// struct ScaleDataTypeToFlag<bf6_t> // e3m2
|
||||
// {
|
||||
// static constexpr std::int32_t value = 3;
|
||||
// };
|
||||
|
||||
template <>
|
||||
struct ScaleDataTypeToFlag<pk_fp4_t> // e2m1
|
||||
{
|
||||
static constexpr std::int32_t value = 4;
|
||||
};
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @concept ScaleMfmaDataTypeToFlag
|
||||
* @brief Expresses the interface of required members for each DataTypeToFlag type on Gfx9
|
||||
*/
|
||||
template <typename DataTypeToFlag>
|
||||
concept ScaleMfmaDataTypeToFlag = requires(DataTypeToFlag dataTypeToFlag) {
|
||||
// Flag members for scale MFMA instructions
|
||||
{ DataTypeToFlag::value } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
template <typename T>
|
||||
inline constexpr std::int32_t ScaleDataTypeToFlag_v = ScaleDataTypeToFlag<T>::value;
|
||||
|
||||
} // namespace scale::detail
|
||||
|
||||
struct DefaultScaleMfmaCtrlFlags
|
||||
{
|
||||
static constexpr std::int32_t OPSEL_A = 0;
|
||||
static constexpr std::int32_t OPSEL_B = 0;
|
||||
};
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @concept ScaleMfmaCtrlFlags
|
||||
* @brief Expresses the interface of required members for each CtrlFlags type on Gfx9
|
||||
*/
|
||||
template <typename CtrlFlags>
|
||||
concept ScaleMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) {
|
||||
// Flag members for scale MFMA instructions
|
||||
{ CtrlFlags::OPSEL_A } -> std::convertible_to<int>;
|
||||
{ CtrlFlags::OPSEL_B } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
43
include/ck_tile/core/arch/mma/scale/scale_transforms.hpp
Normal file
43
include/ck_tile/core/arch/mma/scale/scale_transforms.hpp
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultTransformsScale
|
||||
* @brief Implements the default MMA transforms for Scale
|
||||
*/
|
||||
struct MmaDefaultTransformsScale
|
||||
{
|
||||
using ATransform = PassThroughTransform;
|
||||
using BTransform = PassThroughTransform;
|
||||
using CTransform = PassThroughTransform;
|
||||
using DTransform = PassThroughTransform;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaTransformsDefaultSelector
|
||||
* @brief Specialization for Scale MFMA transforms
|
||||
* Provides default transform selection for scale operations
|
||||
*
|
||||
* @tparam MmaOp Scale MMA operation
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_mma_op_scale(MmaOp))
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::SCALE>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsScale;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
@@ -11,6 +11,10 @@ if(GPU_TARGETS MATCHES "gfx9|gfx12")
|
||||
add_gtest_executable(test_amdgcn_sparse_mma pipeline/test_amdgcn_sparse_mma.cpp)
|
||||
target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx950")
|
||||
add_gtest_executable(test_amdgcn_scale_mma pipeline/test_amdgcn_scale_mma.cpp)
|
||||
target_compile_options(test_amdgcn_scale_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp)
|
||||
target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
@@ -10,23 +10,27 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "../get_wave_size_helper.hpp"
|
||||
|
||||
template <typename AType_ = ck_tile::fp16_t,
|
||||
typename BType_ = ck_tile::fp16_t,
|
||||
typename CType_ = ck_tile::fp32_t,
|
||||
uint32_t WaveTileM_ = 16,
|
||||
uint32_t WaveTileN_ = 16,
|
||||
uint32_t WaveTileK_ = 32>
|
||||
template <typename AType_ = ck_tile::fp16_t,
|
||||
typename BType_ = ck_tile::fp16_t,
|
||||
typename CType_ = ck_tile::fp32_t,
|
||||
uint32_t WaveTileM_ = 16,
|
||||
uint32_t WaveTileN_ = 16,
|
||||
uint32_t WaveTileK_ = 32,
|
||||
typename ScaleAType_ = int,
|
||||
typename ScaleBType_ = int>
|
||||
struct MmaPipelineTest
|
||||
{
|
||||
using AType = AType_;
|
||||
using BType = BType_;
|
||||
using CType = CType_;
|
||||
using ScaleAType = ScaleAType_;
|
||||
using ScaleBType = ScaleBType_;
|
||||
static constexpr auto WaveTileM = WaveTileM_;
|
||||
static constexpr auto WaveTileN = WaveTileN_;
|
||||
static constexpr auto WaveTileK = WaveTileK_;
|
||||
@@ -120,4 +124,109 @@ struct MmaPipelineTest
|
||||
HIP_CHECK_ERROR(hipFree(d_c));
|
||||
HIP_CHECK_ERROR(hipFree(d_out));
|
||||
}
|
||||
|
||||
void
|
||||
test_pipeline(std::function<bool(ck_tile::core::arch::amdgcn_target_id)> shouldSkip,
|
||||
std::function<void(uint32_t, void*, void*, void*, void*, void*, void*)> kernel,
|
||||
std::function<CType(uint32_t, ScaleAType, ScaleBType)> getExpected,
|
||||
std::function<AType(size_t)> aInitializer = nullptr)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
using namespace ck_tile::core::arch;
|
||||
|
||||
int devCount;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceCount(&devCount));
|
||||
|
||||
hipDeviceProp_t devProp;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev));
|
||||
|
||||
auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName);
|
||||
bool hasDevice = static_cast<bool>(devCount > 0);
|
||||
int deviceWarpSize = devProp.warpSize;
|
||||
|
||||
if(!hasDevice || shouldSkip(currentArchId))
|
||||
{
|
||||
GTEST_SKIP() << "No HIP device found. Skipping test.";
|
||||
}
|
||||
|
||||
// WaveTile size, also the expected fragment size (MmaTile) from the selector.
|
||||
// Note: Actual FragK might be slightly different due to hardware implementation, but the
|
||||
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
|
||||
// correct.
|
||||
static constexpr uint32_t FragM = WaveTileM;
|
||||
static constexpr uint32_t FragN = WaveTileN;
|
||||
static constexpr uint32_t FragK = WaveTileK;
|
||||
|
||||
// The number of elements per thread
|
||||
uint32_t AElements = FragM * FragK / deviceWarpSize / numeric_traits<AType>::PackedSize;
|
||||
uint32_t BElements = FragN * FragK / deviceWarpSize / numeric_traits<BType>::PackedSize;
|
||||
uint32_t CElements = FragM * FragN / deviceWarpSize;
|
||||
|
||||
uint32_t ASize = AElements * sizeof(AType);
|
||||
uint32_t BSize = BElements * sizeof(BType);
|
||||
uint32_t CSize = CElements * sizeof(CType);
|
||||
uint32_t ScaleASize = 1 * sizeof(ScaleAType);
|
||||
uint32_t ScaleBSize = 1 * sizeof(ScaleBType);
|
||||
|
||||
// Initialize A (use custom initializer or default all 1's), B to all 1's, C to all 0's
|
||||
std::vector<AType> h_a(AElements);
|
||||
if(aInitializer)
|
||||
{
|
||||
for(size_t i = 0; i < AElements; ++i)
|
||||
h_a[i] = aInitializer(i);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::fill(h_a.begin(), h_a.end(), type_convert<AType>(1.0f));
|
||||
}
|
||||
std::vector<BType> h_b(BElements, type_convert<BType>(1.0f));
|
||||
std::vector<CType> h_c(CElements, type_convert<CType>(0.0f));
|
||||
std::vector<CType> h_out(CElements, type_convert<CType>(0.0f));
|
||||
// The actual scale is computed as pow(2, scale - 127), so:
|
||||
// 126 -> 2^-1 and 129 -> 2^2.
|
||||
ScaleAType h_scale_a = 126;
|
||||
ScaleBType h_scale_b = 129;
|
||||
|
||||
AType* d_a;
|
||||
BType* d_b;
|
||||
CType* d_c;
|
||||
CType* d_out;
|
||||
ScaleAType* d_scale_a;
|
||||
ScaleBType* d_scale_b;
|
||||
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_a, ASize));
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_b, BSize));
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_c, CSize));
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_out, CSize));
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_scale_a, ScaleASize));
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_scale_b, ScaleBSize));
|
||||
|
||||
// Copy inputs to device
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_scale_a, &h_scale_a, ScaleASize, hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_scale_b, &h_scale_b, ScaleBSize, hipMemcpyHostToDevice));
|
||||
|
||||
const auto wave_size = getDeviceWaveSize();
|
||||
kernel(wave_size, d_a, d_b, d_c, d_out, d_scale_a, d_scale_b);
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
|
||||
|
||||
// Verify output against expected value for all elements
|
||||
for(size_t i = 0; i < CElements; ++i)
|
||||
{
|
||||
EXPECT_NEAR(h_out[i], getExpected(FragK, h_scale_a, h_scale_b), 1e-3);
|
||||
}
|
||||
|
||||
HIP_CHECK_ERROR(hipFree(d_a));
|
||||
HIP_CHECK_ERROR(hipFree(d_b));
|
||||
HIP_CHECK_ERROR(hipFree(d_c));
|
||||
HIP_CHECK_ERROR(hipFree(d_out));
|
||||
HIP_CHECK_ERROR(hipFree(d_scale_a));
|
||||
HIP_CHECK_ERROR(hipFree(d_scale_b));
|
||||
}
|
||||
};
|
||||
|
||||
270
test/ck_tile/core/arch/mma/pipeline/test_amdgcn_scale_mma.cpp
Normal file
270
test/ck_tile/core/arch/mma/pipeline/test_amdgcn_scale_mma.cpp
Normal file
@@ -0,0 +1,270 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "pipeline_tests_helper.hpp"
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
|
||||
using namespace ck_tile;
|
||||
using namespace ck_tile::core::arch;
|
||||
using namespace ck_tile::core::arch::mma;
|
||||
|
||||
using CompilerTargetGfx950 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>());
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
std::uint32_t WaveTileM,
|
||||
std::uint32_t WaveTileN,
|
||||
std::uint32_t WaveTileK>
|
||||
void ScaleMfmaGfx950Specialization_impl()
|
||||
{
|
||||
using TestScaleMma = amdgcn_mma<AType,
|
||||
BType,
|
||||
CType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
DefaultScaleMfmaCtrlFlags,
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::SCALE>;
|
||||
|
||||
static_assert(std::is_same_v<typename TestScaleMma::OpType, MfmaOp> &&
|
||||
TestScaleMma::OpFamily == MmaOpFamily::SCALE,
|
||||
"GFX950 scale intrinsic should have ScaleMFMAOp type");
|
||||
|
||||
static_assert(is_mma_op_of_family_v<MmaOpFamily::SCALE, TestScaleMma>,
|
||||
"GFX950 scale intrinsic should be detected as Scale");
|
||||
|
||||
// Get its traits
|
||||
using TestTraits = MmaOpTraits<TestScaleMma>;
|
||||
|
||||
// Verify trait detection
|
||||
static_assert(TestTraits::IsScale, "Scale MMA should be detected as scale");
|
||||
static_assert(TestTraits::IsSupported, "Scale MMA specialization should be supported");
|
||||
static_assert(TestTraits::IsMfma, "Scale MFMA should be detected as MFMA");
|
||||
static_assert(!TestTraits::IsWmma, "Scale MFMA should not be detected as WMMA");
|
||||
}
|
||||
|
||||
TEST(ScaleMMATrait, ScaleMfmaGfx950Specialization)
|
||||
{
|
||||
// Test fp8 → fp32 scale MFMA for GFX950 (16x16x128)
|
||||
ScaleMfmaGfx950Specialization_impl<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u>();
|
||||
// Test bf8 → fp32 scale MFMA for GFX950 (16x16x128)
|
||||
ScaleMfmaGfx950Specialization_impl<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u>();
|
||||
// Test fp4 → fp32 scale MFMA for GFX950 (16x16x128)
|
||||
ScaleMfmaGfx950Specialization_impl<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u>();
|
||||
// Test fp8 → fp32 scale MFMA for GFX950 (32x32x64)
|
||||
ScaleMfmaGfx950Specialization_impl<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u>();
|
||||
// Test bf8 → fp32 scale MFMA for GFX950 (32x32x64)
|
||||
ScaleMfmaGfx950Specialization_impl<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u>();
|
||||
// Test fp4 → fp32 scale MFMA for GFX950 (32x32x64)
|
||||
ScaleMfmaGfx950Specialization_impl<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u>();
|
||||
|
||||
std::cout << "GFX950 scale MFMA specialization is correct" << std::endl;
|
||||
}
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
std::uint32_t WaveTileM,
|
||||
std::uint32_t WaveTileN,
|
||||
std::uint32_t WaveTileK>
|
||||
void TestConceptRequirements_impl()
|
||||
{
|
||||
using TestScaleMma = amdgcn_mma<AType,
|
||||
BType,
|
||||
CType,
|
||||
WaveTileM,
|
||||
WaveTileN,
|
||||
WaveTileK,
|
||||
DefaultScaleMfmaCtrlFlags,
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::SCALE>;
|
||||
static_assert(MmaOpI<TestScaleMma>);
|
||||
}
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
TEST(ScaleMMATrait, TestConceptRequirements)
|
||||
{
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
TestConceptRequirements_impl<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u>();
|
||||
TestConceptRequirements_impl<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u>();
|
||||
TestConceptRequirements_impl<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u>();
|
||||
TestConceptRequirements_impl<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u>();
|
||||
TestConceptRequirements_impl<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u>();
|
||||
TestConceptRequirements_impl<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u>();
|
||||
#else
|
||||
GTEST_SKIP() << "Not compiled with concepts. Skipping test.";
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
}
|
||||
|
||||
template <typename AType, typename BType, typename CType>
|
||||
void ScaleSelector_impl()
|
||||
{
|
||||
static_for<2, 14, 6>{}([](auto k_factor) {
|
||||
static_for<1, 33, 1>{}([&](auto i) {
|
||||
using Selected = typename MmaDefaultSelector<AType,
|
||||
BType,
|
||||
CType,
|
||||
static_cast<std::uint32_t>(i),
|
||||
static_cast<std::uint32_t>(i),
|
||||
static_cast<std::uint32_t>(k_factor * i),
|
||||
CompilerTargetGfx950,
|
||||
MmaOpFamily::SCALE>::SelectedOp;
|
||||
static constexpr bool isValid = (i == 16 && k_factor == 8) || (i == 32);
|
||||
if constexpr(isValid)
|
||||
{
|
||||
// Selector should pick a scale MFMA implementation
|
||||
static_assert(MmaOpTraits<Selected>::IsScale);
|
||||
static_assert(MmaOpTraits<Selected>::IsMfma);
|
||||
static_assert(MmaOpTraits<Selected>::IsSupported);
|
||||
static_assert((std::is_same<typename Selected::OpType, MfmaOp>::value));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Selector should pick the unsupported pass through
|
||||
static_assert(!MmaOpTraits<Selected>::IsSupported);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
TEST(ScaleMMATrait, ScaleSelector)
|
||||
{
|
||||
ScaleSelector_impl<fp8_t, fp8_t, fp32_t>();
|
||||
ScaleSelector_impl<bf8_t, bf8_t, fp32_t>();
|
||||
ScaleSelector_impl<pk_fp4_t, pk_fp4_t, fp32_t>();
|
||||
}
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
typename ScaleAType,
|
||||
typename ScaleBType,
|
||||
std::uint32_t WaveTileM,
|
||||
std::uint32_t WaveTileN,
|
||||
std::uint32_t WaveTileK>
|
||||
__global__ void
|
||||
test_scale_accum_over_k(void* a, void* b, void* c, void* out, void* scale_A, void* scale_B)
|
||||
{
|
||||
using Pipeline = ScaleMmaPipeline<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
|
||||
|
||||
using AVecType = typename Pipeline::AVecType;
|
||||
using BVecType = typename Pipeline::BVecType;
|
||||
using CVecType = typename Pipeline::CVecType;
|
||||
|
||||
// NOTE: WaveTileK is used as a Pipeline template parameter, but the K iteration is
|
||||
// happening outside the Pipeline. This is a bit incorrect currently.
|
||||
static constexpr std::uint32_t kIters = WaveTileK / Pipeline::MmaOp::kK;
|
||||
|
||||
// Initialize the accumulator
|
||||
CVecType result = *reinterpret_cast<CVecType*>(c);
|
||||
|
||||
// Accumulate input AxB over WaveTileK/FragK iterations
|
||||
for(std::uint32_t i = 0; i < kIters; ++i)
|
||||
{
|
||||
result = Pipeline::exec(*reinterpret_cast<AVecType*>(a),
|
||||
*reinterpret_cast<BVecType*>(b),
|
||||
result,
|
||||
*reinterpret_cast<ScaleAType*>(scale_A),
|
||||
*reinterpret_cast<ScaleBType*>(scale_B));
|
||||
}
|
||||
|
||||
*reinterpret_cast<CVecType*>(out) = result;
|
||||
}
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
std::uint32_t WaveTileM,
|
||||
std::uint32_t WaveTileN,
|
||||
std::uint32_t WaveTileK>
|
||||
void MmaSelector_Scale_Real_impl()
|
||||
{
|
||||
using TestType = MmaPipelineTest<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
|
||||
TestType test;
|
||||
const auto should_skip = [](amdgcn_target_id currentArchId) {
|
||||
bool isSupportedWmma = false;
|
||||
bool isSupportedMfma = (currentArchId == amdgcn_target_id::GFX950);
|
||||
return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma));
|
||||
};
|
||||
const std::function<fp32_t(
|
||||
std::uint32_t, typename TestType::ScaleAType, typename TestType::ScaleBType)>
|
||||
validator =
|
||||
[](std::uint32_t fragK, TestType::ScaleAType scale_A, TestType::ScaleBType scale_B) {
|
||||
fp32_t actual_scale_A = std::powf(2.0f, scale_A - 127.0f);
|
||||
fp32_t actual_scale_B = std::powf(2.0f, scale_B - 127.0f);
|
||||
return static_cast<fp32_t>(fragK) * actual_scale_A * actual_scale_B;
|
||||
};
|
||||
const auto kernel = [](std::uint32_t waveSize,
|
||||
void* a,
|
||||
void* b,
|
||||
void* c,
|
||||
void* out,
|
||||
void* scale_A,
|
||||
void* scale_B) {
|
||||
test_scale_accum_over_k<typename TestType::AType,
|
||||
typename TestType::BType,
|
||||
typename TestType::CType,
|
||||
typename TestType::ScaleAType,
|
||||
typename TestType::ScaleBType,
|
||||
TestType::WaveTileM,
|
||||
TestType::WaveTileN,
|
||||
TestType::WaveTileK>
|
||||
<<<1, waveSize>>>(a, b, c, out, scale_A, scale_B);
|
||||
};
|
||||
test.test_pipeline(should_skip, kernel, validator);
|
||||
}
|
||||
|
||||
// Live test on real hardware for scale selection and execution.
|
||||
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_16x16x128_Real)
|
||||
{
|
||||
MmaSelector_Scale_Real_impl<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u>();
|
||||
}
|
||||
|
||||
// Live test on real hardware for scale selection and execution.
|
||||
TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_16x16x128_Real)
|
||||
{
|
||||
MmaSelector_Scale_Real_impl<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u>();
|
||||
}
|
||||
|
||||
// Live test on real hardware for scale selection and execution.
|
||||
TEST(ScaleMMATrait, MmaSelector_Scale_F4_F4_F32_16x16x128_Real)
|
||||
{
|
||||
MmaSelector_Scale_Real_impl<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u>();
|
||||
}
|
||||
|
||||
// Live test on real hardware for scale selection and execution.
|
||||
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_32x32x64_Real)
|
||||
{
|
||||
MmaSelector_Scale_Real_impl<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u>();
|
||||
}
|
||||
|
||||
// Live test on real hardware for scale selection and execution.
|
||||
TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_32x32x64_Real)
|
||||
{
|
||||
MmaSelector_Scale_Real_impl<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u>();
|
||||
}
|
||||
|
||||
// Live test on real hardware for scale selection and execution.
|
||||
TEST(ScaleMMATrait, MmaSelector_Scale_F4_F4_F32_32x32x64_Real)
|
||||
{
|
||||
MmaSelector_Scale_Real_impl<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u>();
|
||||
}
|
||||
@@ -3,18 +3,32 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mfma/mfma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/arch/mma/scale/scale.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
|
||||
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
|
||||
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
// #include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -22,6 +36,9 @@ using namespace ck_tile;
|
||||
using namespace ck_tile::core::arch;
|
||||
using namespace mma;
|
||||
|
||||
// using F4 = pk_fp4_t;
|
||||
using F8 = fp8_t;
|
||||
using BF8 = bf8_t;
|
||||
using F16 = fp16_t;
|
||||
using F32 = fp32_t;
|
||||
using Target908 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX908>());
|
||||
@@ -80,6 +97,10 @@ struct MmaLayoutTestKernel
|
||||
BVecType b_frag{};
|
||||
CVecType c_frag{};
|
||||
uint32_t sparse_idx{};
|
||||
// The actual scale is computed as pow(2, scale - 127), so:
|
||||
// 125 -> 2^-2 and 129 -> 2^2.
|
||||
int scale_A = 125;
|
||||
int scale_B = 129;
|
||||
static_assert(MmaOp::kCompressionRatio <= 2); // Allow only 4:2 compression (or no).
|
||||
|
||||
// get (m, k, n), where "1" should be placed for this block
|
||||
@@ -97,7 +118,7 @@ struct MmaLayoutTestKernel
|
||||
// direction and we just put our "1" in the k / 2 position (rounded down).
|
||||
if(a_coords[0] == m && a_coords[1] == (k / MmaOp::kCompressionRatio))
|
||||
{
|
||||
a_frag[v] = 1;
|
||||
a_frag[v] = type_convert<typename MmaOp::ADataType>(1.0f);
|
||||
|
||||
// Calc an appropriate sparse idx value for a single 1 in position k. We use a
|
||||
// baseline index of 0x88888888. This sends each compressed index i to
|
||||
@@ -114,7 +135,7 @@ struct MmaLayoutTestKernel
|
||||
auto b_coords = BRegMap::calc_matrix_indices_from_lane_vector(lane, v);
|
||||
if(b_coords[0] == n && b_coords[1] == k)
|
||||
{
|
||||
b_frag[v] = 1;
|
||||
b_frag[v] = type_convert<typename MmaOp::BDataType>(1.0f);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,6 +143,10 @@ struct MmaLayoutTestKernel
|
||||
{
|
||||
c_frag = MmaOp::exec(a_frag, b_frag, c_frag, sparse_idx);
|
||||
}
|
||||
else if constexpr(MmaOpTraits<MmaOp>::IsScale)
|
||||
{
|
||||
c_frag = MmaOp::exec(a_frag, b_frag, c_frag, scale_A, scale_B);
|
||||
}
|
||||
else
|
||||
{
|
||||
c_frag = MmaOp::exec(a_frag, b_frag, c_frag);
|
||||
@@ -211,24 +236,30 @@ void run_mma_layout_test()
|
||||
// Lists of intrinsics to test.
|
||||
// clang-format off
|
||||
using Gfx9Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE> // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
|
||||
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
|
||||
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE> // mfma_f32_4x4x4f16
|
||||
>;
|
||||
using Gfx942Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target942, MmaOpFamily::SPARSE> // smfmac_f32_16x16x32_f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target942, MmaOpFamily::SPARSE> // smfmac_f32_16x16x32_f16
|
||||
>;
|
||||
using Gfx950Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, Target950, MmaOpFamily::DENSE> // mfma_f32_16x16x32_f16
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, Target950, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
|
||||
amdgcn_mma<F8, F8, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, Target950, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, Target950, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
|
||||
// amdgcn_mma<F4, F4, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, Target950, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
|
||||
amdgcn_mma<F8, F8, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, Target950, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
|
||||
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, Target950, MmaOpFamily::SCALE> // mfma_scale_f32_32x32x64_f8f6f4
|
||||
// amdgcn_mma<F4, F4, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, Target950, MmaOpFamily::SCALE> // mfma_scale_f32_32x32x64_f8f6f4
|
||||
>;
|
||||
using Gfx11Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target11, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target11, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32
|
||||
>;
|
||||
using Gfx12Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target12, MmaOpFamily::SPARSE> // swmmac_f32_16x16x32_f16_w32
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target12, MmaOpFamily::SPARSE> // swmmac_f32_16x16x32_f16_w32
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user