[CK TILE] Unification of Scale MFMA/WMMA Policy Structs (#5857)

## Motivation

The existing unification work supports DENSE and SPARSE intrinsics. In
this PR, we enable support for SCALE intrinsics and add example SCALE
implementations.

## Technical Details

Adding MFMA SCALE intrinsics support, adding tests for MFMA SCALE
intrinsics, and adding WMMA SCALE policy trait.

Note: fp6 SCALE intrinsics support is not included in this PR, as its
handling in ck_tile is currently more specialized and does not follow
the same pattern as other datatypes.

## Test Plan

Added new tests for the relevant SCALE specialisations.

## Test Result

Test should pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Yung-sheng Tu
2026-04-20 16:28:23 +02:00
committed by GitHub
parent b65d734c87
commit 21acf3ba3a
14 changed files with 1116 additions and 42 deletions

View File

@@ -25,6 +25,13 @@
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp"
#include "ck_tile/core/arch/mma/scale/mfma/selector.hpp"
#include "ck_tile/core/arch/mma/scale/scale.hpp"
#include "ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"
#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp"
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"

View File

@@ -245,7 +245,7 @@ concept MmaOpI = requires(MmaOp op) {
{ MmaOp::kCMPerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMNumAccess } -> std::convertible_to<unsigned int>;
{ MmaOp::kCompressionRatio } -> std::convertible_to<unsigned int>;
} && (HasExecSignature<MmaOp> || HasExecSignature<MmaOp, int>);
} && (HasExecSignature<MmaOp> || HasExecSignature<MmaOp, int> || HasExecSignature<MmaOp, int, int>);
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
@@ -303,6 +303,8 @@ struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1
#pragma clang diagnostic pop
// Include the implementations
#include "wmma/wmma.hpp"
#include "wmma/wmma.hpp" // should be included before the below headers
#include "mfma/mfma.hpp"
#include "scale/scale.hpp"
#include "sparse/sparse.hpp"

View File

@@ -207,11 +207,17 @@ struct MmaPipelineBase
/**
* @brief Apply the per-operand pre-transforms and buffer formatting to A, B, and C.
* @return A @c std::tuple of the transformed (A, B, C) vectors ready for the mma loop.
* @return A @c std::tuple of the transformed (A, B, C, [scaleA, scaleB]) vectors ready for the
* mma loop.
*/
template <typename ATransformInputs, typename BTransformInputs, typename CTransformInputs>
CK_TILE_DEVICE static decltype(auto)
applyTransformsToInputs(ATransformInputs&& a, BTransformInputs&& b, CTransformInputs&& accum)
template <typename ATransformInputs,
typename BTransformInputs,
typename CTransformInputs,
typename... ExtraArgs>
CK_TILE_DEVICE static decltype(auto) applyTransformsToInputs(ATransformInputs&& a,
BTransformInputs&& b,
CTransformInputs&& accum,
ExtraArgs&&... extras)
{
using InternalAVecT = typename Derived::InternalAVecT;
using InternalBVecT = typename Derived::InternalBVecT;
@@ -224,19 +230,18 @@ struct MmaPipelineBase
return std::make_tuple(
preApplyTransform<InternalAVecT, ATransform>(std::forward<ATransformInputs>(a)),
preApplyTransform<InternalBVecT, BTransform>(std::forward<BTransformInputs>(b)),
preApplyTransform<InternalCVecT, CTransform>(std::forward<CTransformInputs>(accum)));
preApplyTransform<InternalCVecT, CTransform>(std::forward<CTransformInputs>(accum)),
std::forward<ExtraArgs>(extras)...);
}
/**
* @brief Apply the post-transform and buffer formatting to the C (accumulator) output.
* @param vecs The (A, B, C) tuple after @c execImpl; only C is consumed.
* @param c_result The accumulator to post-process.
* @return The final D output in the user-facing vector type.
*/
template <typename ATransformResult, typename BTransformResult, typename CTransformResult>
CK_TILE_DEVICE static auto
applyTransformToOutput(std::tuple<ATransformResult, BTransformResult, CTransformResult>&& vecs)
template <typename CTransformResult>
CK_TILE_DEVICE static auto applyTransformToOutput(CTransformResult&& c_result)
{
auto&& [a_result, b_result, c_result] = vecs;
static_assert(!is_std_tuple_v<decltype(c_result)>,
"If CTransform returns more than the vector, update this function.");
@@ -270,7 +275,46 @@ struct MmaPipelineBase
Derived::execImpl(transformed_inputs);
return applyTransformToOutput(std::move(transformed_inputs));
auto&& [a_result, b_result, c_result] = std::move(transformed_inputs);
return applyTransformToOutput(std::move(c_result));
}
else
{
// Return the unsupported exec. This should print a runtime warning. (amdgcn_mma.hpp)
// Code should not reach here, but HOST/DEVICE compile passes are
// weirdly intertwined and instead of having constexpr in the calling
// site (tests) we do this. See also changes by this commit.
return Derived::MmaOp::exec({}, {}, {});
}
}
template <typename VecTA,
typename VecTB,
typename VecTC,
typename ScaleADataType,
typename ScaleBDataType>
CK_TILE_DEVICE static decltype(auto)
exec(VecTA&& a, VecTB&& b, VecTC&& accum, ScaleADataType&& scale_A, ScaleBDataType&& scale_B)
{
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
{
// TODO: c++20: Call template functions with MmaPipelineOptionFlags directly
auto transformed_inputs = applyTransformsToInputs(
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTB>(b)
: std::forward<VecTA>(a),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTA>(a)
: std::forward<VecTB>(b),
std::forward<VecTC>(accum),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<ScaleBDataType>(scale_B)
: std::forward<ScaleADataType>(scale_A),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<ScaleADataType>(scale_A)
: std::forward<ScaleBDataType>(scale_B));
Derived::execImpl(transformed_inputs);
auto&& [a_result, b_result, c_result, scale_A_result, scale_B_result] =
std::move(transformed_inputs);
return applyTransformToOutput(std::move(c_result));
}
else
{

View File

@@ -0,0 +1,229 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
namespace ck_tile::core::arch::mma {
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
*
* This specialization implements the Scale MFMA instruction for fp8_t A and B
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 2, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
// clang-format on
{
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
static_cast<int>(CtrlFlags::OPSEL_A),
scale_A,
static_cast<int>(CtrlFlags::OPSEL_B),
scale_B)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
*
* This specialization implements the Scale MFMA instruction for bf8_t A and B
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u, 64u, 32, 2, 1, 2, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
// clang-format on
{
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
static_cast<int>(CtrlFlags::OPSEL_A),
scale_A,
static_cast<int>(CtrlFlags::OPSEL_B),
scale_B)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
*
* This specialization implements the Scale MFMA instruction for pk_fp4_t A and B
* matrices with fp32_t accumulator, with 16x16x128 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u, 64u, 32, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SCALE>
// clang-format on
{
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
static_cast<int>(CtrlFlags::OPSEL_A),
scale_A,
static_cast<int>(CtrlFlags::OPSEL_B),
scale_B)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
*
* This specialization implements the Scale MFMA instruction for fp8_t A and B
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 2, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
// clang-format on
{
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
scale::detail::ScaleDataTypeToFlag_v<fp8_t>,
static_cast<int>(CtrlFlags::OPSEL_A),
scale_A,
static_cast<int>(CtrlFlags::OPSEL_B),
scale_B)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
*
* This specialization implements the Scale MFMA instruction for bf8_t A and B
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u, 64u, 32, 2, 1, 2, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
// clang-format on
{
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
scale::detail::ScaleDataTypeToFlag_v<bf8_t>,
static_cast<int>(CtrlFlags::OPSEL_A),
scale_A,
static_cast<int>(CtrlFlags::OPSEL_B),
scale_B)};
}
};
/**
* @struct amdgcn_mma
* @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets
*
* This specialization implements the Scale MFMA instruction for pk_fp4_t A and B
* matrices with fp32_t accumulator, with 32x32x64 block sizes.
*
* @tparam CtrlFlags Control flags for the Scale MFMA operation
* @tparam CompilerTarget Current compiler target
*/
// TODO: c++20 template <CtrlFlagsScaleMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, CtrlFlags, CompilerTarget, MmaOpFamily::SCALE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
: amdgcn_mma_base<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u, 64u, 32, 1, 1, 1, 1, 16, 4, MfmaOp, MmaOpFamily::SCALE>
// clang-format on
{
CK_TILE_DEVICE static CVecType
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B)
{
return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
bit_cast<int32x8_t>(aVec),
bit_cast<int32x8_t>(bVec),
cVec,
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
scale::detail::ScaleDataTypeToFlag_v<pk_fp4_t>,
static_cast<int>(CtrlFlags::OPSEL_A),
scale_A,
static_cast<int>(CtrlFlags::OPSEL_B),
scale_B)};
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,149 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp"
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <cstdint>
#include <type_traits>
namespace ck_tile::core::arch::mma {
/**
* @class ScaleMfmaDefaultSelector
* @brief Implements a default scale MFMA selector strategy. The SelectedOp can be unsupported.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM Size of the M dimension
* @tparam WaveTileN Size of the N dimension
* @tparam WaveTileKTest Size of the K dimension
* @tparam CompilerTarget The compiler target
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileKTest,
typename CompilerTarget>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) &&
// is_power_of_two_integer(WaveTileKTest))
struct ScaleMfmaDefaultSelector
{
private:
// Define our candidate MFMA implementation for the current parameters
using CandidateOp = amdgcn_mma<ADataType,
BDataType,
CDataType,
WaveTileM,
WaveTileN,
WaveTileKTest,
DefaultScaleMfmaCtrlFlags,
CompilerTarget,
MmaOpFamily::SCALE>;
public:
// If the candidate is supported (e.g., a backend implementation exists), then select it.
// Otherwise, fall back to the unsupported pass-through implementation.
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
CandidateOp,
amdgcn_mma<ADataType,
BDataType,
CDataType,
WaveTileM,
WaveTileN,
WaveTileKTest,
void,
amdgcn_target<>,
MmaOpFamily::UNDEFINED>>;
};
/**
* @struct MmaDefaultSelector
* @brief Implements the CDNA default MMA selector strategy for scale MFMA.
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
* @tparam ADataType Data type of matrix A
* @tparam BDataType Data type of matrix B
* @tparam CDataType Data type of the accumulator
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
* @tparam CompilerTarget The compiler target
* @tparam OpFamily The MMA operation family
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK,
typename CompilerTarget,
MmaOpFamily OpFamily>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires
struct MmaDefaultSelector<ADataType,
BDataType,
CDataType,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
OpFamily,
enable_if_all<std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID,
amdgcn_target_id::GFX950)>,
std::enable_if_t<OpFamily == MmaOpFamily::SCALE>>>
{
private:
// Provide the default depth-K search strategy for each class of common MFMA shapes.
// Start searching from the largest K dimension MFMA shape down to the smallest.
using CandidateOp16x16 = typename ScaleMfmaDefaultSelector<ADataType,
BDataType,
CDataType,
16u,
16u,
128u,
CompilerTarget>::SelectedOp;
using CandidateOp32x32 = typename ScaleMfmaDefaultSelector<ADataType,
BDataType,
CDataType,
32u,
32u,
64u,
CompilerTarget>::SelectedOp;
// Default operation triggers pass-through
using DefaultOp = typename ScaleMfmaDefaultSelector<ADataType,
BDataType,
CDataType,
1u,
1u,
1u,
CompilerTarget>::SelectedOp;
// Check if each candidate is supported for the given fragment sizes
// For this case, we require the fragment sizes to be multiples of the MFMA shape
static constexpr bool IsSupported16x16 =
MmaOpTraits<CandidateOp16x16>::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) &&
(WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u);
static constexpr bool IsSupported32x32 =
MmaOpTraits<CandidateOp32x32>::IsSupported && (WaveTileM % CandidateOp32x32::kM == 0u) &&
(WaveTileN % CandidateOp32x32::kN == 0u) && (WaveTileK % CandidateOp32x32::kK == 0u);
public:
// Select the largest supported MFMA operation for the given fragment shape
using SelectedOp =
std::conditional_t<IsSupported32x32,
CandidateOp32x32,
std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>>;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,10 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
// Include scale MFMA traits and architecture-specific implementations
#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp"
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"

View File

@@ -0,0 +1,77 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"
#include "ck_tile/core/config.hpp"
#include <cstdint>
#include <tuple>
#include <type_traits>
#include <utility>
namespace ck_tile::core::arch::mma {
template <typename ADataType,
typename BDataType,
typename CDataType,
std::uint32_t FragM,
std::uint32_t FragN,
std::uint32_t FragK,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
typename MmaOp_ =
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp_ = typename
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
CompilerTarget,
MmaOpFamily::SCALE>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
using MmaOp = MmaOp_; // Expose the selected MmaOp
// Expose caller-side vector types
using AVecType = typename MmaOp::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
// Expose internal vector types
using InternalAVecT = typename MmaOp::AVecType;
using InternalBVecT = typename MmaOp::BVecType;
using InternalCVecT = typename MmaOp::CVecType;
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
template <typename VecTA,
typename VecTB,
typename VecTC,
typename ScaleADataType,
typename ScaleBDataType>
CK_TILE_DEVICE static void
execImpl(std::tuple<VecTA, VecTB, VecTC, ScaleADataType, ScaleBDataType>& vecs)
{
auto& [a_vec, b_vec, c_vec, scale_A, scale_B] = vecs;
c_vec = MmaOp::exec(a_vec, b_vec, c_vec, scale_A, scale_B);
}
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,6 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/mma/scale/mfma/selector.hpp"

View File

@@ -0,0 +1,93 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
// #include "ck_tile/core/numeric/pk_fp6.hpp"
#include <cstdint>
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
#include <concepts>
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
namespace ck_tile::core::arch::mma {
namespace scale::detail {
template <typename T>
struct ScaleDataTypeToFlag;
template <>
struct ScaleDataTypeToFlag<fp8_t> // e4m3
{
static constexpr std::int32_t value = 0;
};
template <>
struct ScaleDataTypeToFlag<bf8_t> // e5m2
{
static constexpr std::int32_t value = 1;
};
// template <>
// struct ScaleDataTypeToFlag<pk_fp6_t<1>> // e2m3
// {
// static constexpr std::int32_t value = 2;
// };
// template <>
// struct ScaleDataTypeToFlag<bf6_t> // e3m2
// {
// static constexpr std::int32_t value = 3;
// };
template <>
struct ScaleDataTypeToFlag<pk_fp4_t> // e2m1
{
static constexpr std::int32_t value = 4;
};
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
/**
* @concept ScaleMfmaDataTypeToFlag
* @brief Expresses the interface of required members for each DataTypeToFlag type on Gfx9
*/
template <typename DataTypeToFlag>
concept ScaleMfmaDataTypeToFlag = requires(DataTypeToFlag dataTypeToFlag) {
// Flag members for scale MFMA instructions
{ DataTypeToFlag::value } -> std::convertible_to<int>;
};
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
template <typename T>
inline constexpr std::int32_t ScaleDataTypeToFlag_v = ScaleDataTypeToFlag<T>::value;
} // namespace scale::detail
struct DefaultScaleMfmaCtrlFlags
{
static constexpr std::int32_t OPSEL_A = 0;
static constexpr std::int32_t OPSEL_B = 0;
};
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
/**
* @concept ScaleMfmaCtrlFlags
* @brief Expresses the interface of required members for each CtrlFlags type on Gfx9
*/
template <typename CtrlFlags>
concept ScaleMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) {
// Flag members for scale MFMA instructions
{ CtrlFlags::OPSEL_A } -> std::convertible_to<int>;
{ CtrlFlags::OPSEL_B } -> std::convertible_to<int>;
};
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,43 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
#include <type_traits>
namespace ck_tile::core::arch::mma {
/**
* @struct MmaDefaultTransformsScale
* @brief Implements the default MMA transforms for Scale
*/
struct MmaDefaultTransformsScale
{
using ATransform = PassThroughTransform;
using BTransform = PassThroughTransform;
using CTransform = PassThroughTransform;
using DTransform = PassThroughTransform;
};
/**
* @struct MmaTransformsDefaultSelector
* @brief Specialization for Scale MFMA transforms
* Provides default transform selection for scale operations
*
* @tparam MmaOp Scale MMA operation
* @tparam CompilerTarget The compiler target
*/
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires(is_mma_op_scale(MmaOp))
template <typename MmaOp, typename CompilerTarget>
struct MmaTransformsDefaultSelector<MmaOp,
CompilerTarget,
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::SCALE>>
{
using SelectedTransforms = MmaDefaultTransformsScale;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -11,6 +11,10 @@ if(GPU_TARGETS MATCHES "gfx9|gfx12")
add_gtest_executable(test_amdgcn_sparse_mma pipeline/test_amdgcn_sparse_mma.cpp)
target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
if(GPU_TARGETS MATCHES "gfx950")
add_gtest_executable(test_amdgcn_scale_mma pipeline/test_amdgcn_scale_mma.cpp)
target_compile_options(test_amdgcn_scale_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp)
target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -10,23 +10,27 @@
#include <gtest/gtest.h>
#include "ck_tile/core/arch/arch.hpp"
#include <hip/hip_runtime.h>
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
#include "../get_wave_size_helper.hpp"
template <typename AType_ = ck_tile::fp16_t,
typename BType_ = ck_tile::fp16_t,
typename CType_ = ck_tile::fp32_t,
uint32_t WaveTileM_ = 16,
uint32_t WaveTileN_ = 16,
uint32_t WaveTileK_ = 32>
template <typename AType_ = ck_tile::fp16_t,
typename BType_ = ck_tile::fp16_t,
typename CType_ = ck_tile::fp32_t,
uint32_t WaveTileM_ = 16,
uint32_t WaveTileN_ = 16,
uint32_t WaveTileK_ = 32,
typename ScaleAType_ = int,
typename ScaleBType_ = int>
struct MmaPipelineTest
{
using AType = AType_;
using BType = BType_;
using CType = CType_;
using ScaleAType = ScaleAType_;
using ScaleBType = ScaleBType_;
static constexpr auto WaveTileM = WaveTileM_;
static constexpr auto WaveTileN = WaveTileN_;
static constexpr auto WaveTileK = WaveTileK_;
@@ -120,4 +124,109 @@ struct MmaPipelineTest
HIP_CHECK_ERROR(hipFree(d_c));
HIP_CHECK_ERROR(hipFree(d_out));
}
void
test_pipeline(std::function<bool(ck_tile::core::arch::amdgcn_target_id)> shouldSkip,
std::function<void(uint32_t, void*, void*, void*, void*, void*, void*)> kernel,
std::function<CType(uint32_t, ScaleAType, ScaleBType)> getExpected,
std::function<AType(size_t)> aInitializer = nullptr)
{
using namespace ck_tile;
using namespace ck_tile::core::arch;
int devCount;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceCount(&devCount));
hipDeviceProp_t devProp;
HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev));
auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName);
bool hasDevice = static_cast<bool>(devCount > 0);
int deviceWarpSize = devProp.warpSize;
if(!hasDevice || shouldSkip(currentArchId))
{
GTEST_SKIP() << "No HIP device found. Skipping test.";
}
// WaveTile size, also the expected fragment size (MmaTile) from the selector.
// Note: Actual FragK might be slightly different due to hardware implementation, but the
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
// correct.
static constexpr uint32_t FragM = WaveTileM;
static constexpr uint32_t FragN = WaveTileN;
static constexpr uint32_t FragK = WaveTileK;
// The number of elements per thread
uint32_t AElements = FragM * FragK / deviceWarpSize / numeric_traits<AType>::PackedSize;
uint32_t BElements = FragN * FragK / deviceWarpSize / numeric_traits<BType>::PackedSize;
uint32_t CElements = FragM * FragN / deviceWarpSize;
uint32_t ASize = AElements * sizeof(AType);
uint32_t BSize = BElements * sizeof(BType);
uint32_t CSize = CElements * sizeof(CType);
uint32_t ScaleASize = 1 * sizeof(ScaleAType);
uint32_t ScaleBSize = 1 * sizeof(ScaleBType);
// Initialize A (use custom initializer or default all 1's), B to all 1's, C to all 0's
std::vector<AType> h_a(AElements);
if(aInitializer)
{
for(size_t i = 0; i < AElements; ++i)
h_a[i] = aInitializer(i);
}
else
{
std::fill(h_a.begin(), h_a.end(), type_convert<AType>(1.0f));
}
std::vector<BType> h_b(BElements, type_convert<BType>(1.0f));
std::vector<CType> h_c(CElements, type_convert<CType>(0.0f));
std::vector<CType> h_out(CElements, type_convert<CType>(0.0f));
// The actual scale is computed as pow(2, scale - 127), so:
// 126 -> 2^-1 and 129 -> 2^2.
ScaleAType h_scale_a = 126;
ScaleBType h_scale_b = 129;
AType* d_a;
BType* d_b;
CType* d_c;
CType* d_out;
ScaleAType* d_scale_a;
ScaleBType* d_scale_b;
HIP_CHECK_ERROR(hipMalloc(&d_a, ASize));
HIP_CHECK_ERROR(hipMalloc(&d_b, BSize));
HIP_CHECK_ERROR(hipMalloc(&d_c, CSize));
HIP_CHECK_ERROR(hipMalloc(&d_out, CSize));
HIP_CHECK_ERROR(hipMalloc(&d_scale_a, ScaleASize));
HIP_CHECK_ERROR(hipMalloc(&d_scale_b, ScaleBSize));
// Copy inputs to device
HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_scale_a, &h_scale_a, ScaleASize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_scale_b, &h_scale_b, ScaleBSize, hipMemcpyHostToDevice));
const auto wave_size = getDeviceWaveSize();
kernel(wave_size, d_a, d_b, d_c, d_out, d_scale_a, d_scale_b);
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
// Verify output against expected value for all elements
for(size_t i = 0; i < CElements; ++i)
{
EXPECT_NEAR(h_out[i], getExpected(FragK, h_scale_a, h_scale_b), 1e-3);
}
HIP_CHECK_ERROR(hipFree(d_a));
HIP_CHECK_ERROR(hipFree(d_b));
HIP_CHECK_ERROR(hipFree(d_c));
HIP_CHECK_ERROR(hipFree(d_out));
HIP_CHECK_ERROR(hipFree(d_scale_a));
HIP_CHECK_ERROR(hipFree(d_scale_b));
}
};

View File

@@ -0,0 +1,270 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "pipeline_tests_helper.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include <gtest/gtest.h>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <type_traits>
using namespace ck_tile;
using namespace ck_tile::core::arch;
using namespace ck_tile::core::arch::mma;
using CompilerTargetGfx950 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>());
template <typename AType,
typename BType,
typename CType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK>
void ScaleMfmaGfx950Specialization_impl()
{
using TestScaleMma = amdgcn_mma<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
DefaultScaleMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SCALE>;
static_assert(std::is_same_v<typename TestScaleMma::OpType, MfmaOp> &&
TestScaleMma::OpFamily == MmaOpFamily::SCALE,
"GFX950 scale intrinsic should have ScaleMFMAOp type");
static_assert(is_mma_op_of_family_v<MmaOpFamily::SCALE, TestScaleMma>,
"GFX950 scale intrinsic should be detected as Scale");
// Get its traits
using TestTraits = MmaOpTraits<TestScaleMma>;
// Verify trait detection
static_assert(TestTraits::IsScale, "Scale MMA should be detected as scale");
static_assert(TestTraits::IsSupported, "Scale MMA specialization should be supported");
static_assert(TestTraits::IsMfma, "Scale MFMA should be detected as MFMA");
static_assert(!TestTraits::IsWmma, "Scale MFMA should not be detected as WMMA");
}
TEST(ScaleMMATrait, ScaleMfmaGfx950Specialization)
{
// Test fp8 → fp32 scale MFMA for GFX950 (16x16x128)
ScaleMfmaGfx950Specialization_impl<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u>();
// Test bf8 → fp32 scale MFMA for GFX950 (16x16x128)
ScaleMfmaGfx950Specialization_impl<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u>();
// Test fp4 → fp32 scale MFMA for GFX950 (16x16x128)
ScaleMfmaGfx950Specialization_impl<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u>();
// Test fp8 → fp32 scale MFMA for GFX950 (32x32x64)
ScaleMfmaGfx950Specialization_impl<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u>();
// Test bf8 → fp32 scale MFMA for GFX950 (32x32x64)
ScaleMfmaGfx950Specialization_impl<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u>();
// Test fp4 → fp32 scale MFMA for GFX950 (32x32x64)
ScaleMfmaGfx950Specialization_impl<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u>();
std::cout << "GFX950 scale MFMA specialization is correct" << std::endl;
}
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
template <typename AType,
typename BType,
typename CType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK>
void TestConceptRequirements_impl()
{
using TestScaleMma = amdgcn_mma<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
DefaultScaleMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SCALE>;
static_assert(MmaOpI<TestScaleMma>);
}
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
TEST(ScaleMMATrait, TestConceptRequirements)
{
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
TestConceptRequirements_impl<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u>();
TestConceptRequirements_impl<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u>();
TestConceptRequirements_impl<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u>();
TestConceptRequirements_impl<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u>();
TestConceptRequirements_impl<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u>();
TestConceptRequirements_impl<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u>();
#else
GTEST_SKIP() << "Not compiled with concepts. Skipping test.";
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
}
template <typename AType, typename BType, typename CType>
void ScaleSelector_impl()
{
static_for<2, 14, 6>{}([](auto k_factor) {
static_for<1, 33, 1>{}([&](auto i) {
using Selected = typename MmaDefaultSelector<AType,
BType,
CType,
static_cast<std::uint32_t>(i),
static_cast<std::uint32_t>(i),
static_cast<std::uint32_t>(k_factor * i),
CompilerTargetGfx950,
MmaOpFamily::SCALE>::SelectedOp;
static constexpr bool isValid = (i == 16 && k_factor == 8) || (i == 32);
if constexpr(isValid)
{
// Selector should pick a scale MFMA implementation
static_assert(MmaOpTraits<Selected>::IsScale);
static_assert(MmaOpTraits<Selected>::IsMfma);
static_assert(MmaOpTraits<Selected>::IsSupported);
static_assert((std::is_same<typename Selected::OpType, MfmaOp>::value));
}
else
{
// Selector should pick the unsupported pass through
static_assert(!MmaOpTraits<Selected>::IsSupported);
}
});
});
}
TEST(ScaleMMATrait, ScaleSelector)
{
ScaleSelector_impl<fp8_t, fp8_t, fp32_t>();
ScaleSelector_impl<bf8_t, bf8_t, fp32_t>();
ScaleSelector_impl<pk_fp4_t, pk_fp4_t, fp32_t>();
}
template <typename AType,
typename BType,
typename CType,
typename ScaleAType,
typename ScaleBType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK>
__global__ void
test_scale_accum_over_k(void* a, void* b, void* c, void* out, void* scale_A, void* scale_B)
{
using Pipeline = ScaleMmaPipeline<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
// NOTE: WaveTileK is used as a Pipeline template parameter, but the K iteration is
// happening outside the Pipeline. This is a bit incorrect currently.
static constexpr std::uint32_t kIters = WaveTileK / Pipeline::MmaOp::kK;
// Initialize the accumulator
CVecType result = *reinterpret_cast<CVecType*>(c);
// Accumulate input AxB over WaveTileK/FragK iterations
for(std::uint32_t i = 0; i < kIters; ++i)
{
result = Pipeline::exec(*reinterpret_cast<AVecType*>(a),
*reinterpret_cast<BVecType*>(b),
result,
*reinterpret_cast<ScaleAType*>(scale_A),
*reinterpret_cast<ScaleBType*>(scale_B));
}
*reinterpret_cast<CVecType*>(out) = result;
}
template <typename AType,
typename BType,
typename CType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK>
void MmaSelector_Scale_Real_impl()
{
using TestType = MmaPipelineTest<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
TestType test;
const auto should_skip = [](amdgcn_target_id currentArchId) {
bool isSupportedWmma = false;
bool isSupportedMfma = (currentArchId == amdgcn_target_id::GFX950);
return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma));
};
const std::function<fp32_t(
std::uint32_t, typename TestType::ScaleAType, typename TestType::ScaleBType)>
validator =
[](std::uint32_t fragK, TestType::ScaleAType scale_A, TestType::ScaleBType scale_B) {
fp32_t actual_scale_A = std::powf(2.0f, scale_A - 127.0f);
fp32_t actual_scale_B = std::powf(2.0f, scale_B - 127.0f);
return static_cast<fp32_t>(fragK) * actual_scale_A * actual_scale_B;
};
const auto kernel = [](std::uint32_t waveSize,
void* a,
void* b,
void* c,
void* out,
void* scale_A,
void* scale_B) {
test_scale_accum_over_k<typename TestType::AType,
typename TestType::BType,
typename TestType::CType,
typename TestType::ScaleAType,
typename TestType::ScaleBType,
TestType::WaveTileM,
TestType::WaveTileN,
TestType::WaveTileK>
<<<1, waveSize>>>(a, b, c, out, scale_A, scale_B);
};
test.test_pipeline(should_skip, kernel, validator);
}
// Live test on real hardware for scale selection and execution.
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_16x16x128_Real)
{
MmaSelector_Scale_Real_impl<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u>();
}
// Live test on real hardware for scale selection and execution.
TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_16x16x128_Real)
{
MmaSelector_Scale_Real_impl<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u>();
}
// Live test on real hardware for scale selection and execution.
TEST(ScaleMMATrait, MmaSelector_Scale_F4_F4_F32_16x16x128_Real)
{
MmaSelector_Scale_Real_impl<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u>();
}
// Live test on real hardware for scale selection and execution.
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_32x32x64_Real)
{
MmaSelector_Scale_Real_impl<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u>();
}
// Live test on real hardware for scale selection and execution.
TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_32x32x64_Real)
{
MmaSelector_Scale_Real_impl<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u>();
}
// Live test on real hardware for scale selection and execution.
TEST(ScaleMMATrait, MmaSelector_Scale_F4_F4_F32_32x32x64_Real)
{
MmaSelector_Scale_Real_impl<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u>();
}

View File

@@ -3,18 +3,32 @@
#pragma once
#include <hip/hip_runtime.h>
#include <gtest/gtest.h>
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/scale/scale.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integer.hpp"
// #include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/stream_config.hpp"
#include <gtest/gtest.h>
#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>
#include <cmath>
#include <cstdint>
#include <vector>
namespace {
@@ -22,6 +36,9 @@ using namespace ck_tile;
using namespace ck_tile::core::arch;
using namespace mma;
// using F4 = pk_fp4_t;
using F8 = fp8_t;
using BF8 = bf8_t;
using F16 = fp16_t;
using F32 = fp32_t;
using Target908 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX908>());
@@ -80,6 +97,10 @@ struct MmaLayoutTestKernel
BVecType b_frag{};
CVecType c_frag{};
uint32_t sparse_idx{};
// The actual scale is computed as pow(2, scale - 127), so:
// 125 -> 2^-2 and 129 -> 2^2.
int scale_A = 125;
int scale_B = 129;
static_assert(MmaOp::kCompressionRatio <= 2); // Allow only 4:2 compression (or no).
// get (m, k, n), where "1" should be placed for this block
@@ -97,7 +118,7 @@ struct MmaLayoutTestKernel
// direction and we just put our "1" in the k / 2 position (rounded down).
if(a_coords[0] == m && a_coords[1] == (k / MmaOp::kCompressionRatio))
{
a_frag[v] = 1;
a_frag[v] = type_convert<typename MmaOp::ADataType>(1.0f);
// Calc an appropriate sparse idx value for a single 1 in position k. We use a
// baseline index of 0x88888888. This sends each compressed index i to
@@ -114,7 +135,7 @@ struct MmaLayoutTestKernel
auto b_coords = BRegMap::calc_matrix_indices_from_lane_vector(lane, v);
if(b_coords[0] == n && b_coords[1] == k)
{
b_frag[v] = 1;
b_frag[v] = type_convert<typename MmaOp::BDataType>(1.0f);
}
}
@@ -122,6 +143,10 @@ struct MmaLayoutTestKernel
{
c_frag = MmaOp::exec(a_frag, b_frag, c_frag, sparse_idx);
}
else if constexpr(MmaOpTraits<MmaOp>::IsScale)
{
c_frag = MmaOp::exec(a_frag, b_frag, c_frag, scale_A, scale_B);
}
else
{
c_frag = MmaOp::exec(a_frag, b_frag, c_frag);
@@ -211,24 +236,30 @@ void run_mma_layout_test()
// Lists of intrinsics to test.
// clang-format off
using Gfx9Intrinsics = ::testing::Types<
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE> // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_16x16x16f16
amdgcn_mma<F16, F16, F32, 64u, 32u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
amdgcn_mma<F16, F16, F32, 32u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_32x32x4f16
amdgcn_mma<F16, F16, F32, 64u, 4u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE>, // mfma_f32_4x4x4f16
amdgcn_mma<F16, F16, F32, 4u, 64u, 4u, DefaultMfmaCtrlFlags, Target908, MmaOpFamily::DENSE> // mfma_f32_4x4x4f16
>;
using Gfx942Intrinsics = ::testing::Types<
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target942, MmaOpFamily::SPARSE> // smfmac_f32_16x16x32_f16
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target942, MmaOpFamily::SPARSE> // smfmac_f32_16x16x32_f16
>;
using Gfx950Intrinsics = ::testing::Types<
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, Target950, MmaOpFamily::DENSE> // mfma_f32_16x16x32_f16
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultMfmaCtrlFlags, Target950, MmaOpFamily::DENSE>, // mfma_f32_16x16x32_f16
amdgcn_mma<F8, F8, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, Target950, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, Target950, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
// amdgcn_mma<F4, F4, F32, 16u, 16u, 128u, DefaultScaleMfmaCtrlFlags, Target950, MmaOpFamily::SCALE>, // mfma_scale_f32_16x16x128_f8f6f4
amdgcn_mma<F8, F8, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, Target950, MmaOpFamily::SCALE>, // mfma_scale_f32_32x32x64_f8f6f4
amdgcn_mma<BF8, BF8, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, Target950, MmaOpFamily::SCALE> // mfma_scale_f32_32x32x64_f8f6f4
// amdgcn_mma<F4, F4, F32, 32u, 32u, 64u, DefaultScaleMfmaCtrlFlags, Target950, MmaOpFamily::SCALE> // mfma_scale_f32_32x32x64_f8f6f4
>;
using Gfx11Intrinsics = ::testing::Types<
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target11, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target11, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32
>;
using Gfx12Intrinsics = ::testing::Types<
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target12, MmaOpFamily::SPARSE> // swmmac_f32_16x16x32_f16_w32
amdgcn_mma<F16, F16, F32, 16u, 16u, 16u, DefaultWmmaCtrlFlags<F16, F16, F32>, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultSparseMfmaCtrlFlags, Target12, MmaOpFamily::SPARSE> // swmmac_f32_16x16x32_f16_w32
>;
// clang-format on