mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4837 (commit 6316035)
[CK TILE] Unification of sparse MFMA/WMMA policy structs (#4837) ## Motivation The existing unification work supports DENSE intrinsics. In this PR we enable support for SPARSE as well as SCALE intrinsics and add an example SPARSE implementation. ## Technical Details Mostly trivial changes. One framework change is that the desired `MmaOpFamily` is passed to the `MmaDefaultSelector`. As my relevant commit explains, we do not support a fallback family at the moment, but it is something we can consider. ## Test Plan Added a new test for the relevant sparse specializations. ## Test Result Test should pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
6e558658ea
commit
03ce21ddcb
@@ -823,6 +823,12 @@ using enable_if_target_wave64_t =
|
||||
|
||||
#endif // __cplusplus <= 201703L
|
||||
|
||||
template <typename... Ts>
|
||||
constexpr bool all_types_void = std::conjunction_v<std::is_same<void, Ts>...>;
|
||||
|
||||
template <typename... Enablers>
|
||||
using enable_if_all = std::enable_if_t<all_types_void<Enablers...>>;
|
||||
|
||||
} // namespace core::arch
|
||||
|
||||
CK_TILE_HOST bool is_wave32()
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
@@ -82,11 +83,13 @@ template <typename ADataType,
|
||||
uint32_t BlockK,
|
||||
typename CtrlFlags,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily_,
|
||||
typename Enabler = void>
|
||||
struct amdgcn_mma
|
||||
{
|
||||
// The base instance is unsupported because there is no __builtin to wrap.
|
||||
using OpType = Unsupported;
|
||||
using OpType = Unsupported;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::UNDEFINED;
|
||||
|
||||
// Interface types for A, B, C vectors types
|
||||
using AVecType = ext_vector_t<ADataType, 1>;
|
||||
@@ -122,3 +125,4 @@ struct amdgcn_mma
|
||||
// Include the implementations
|
||||
#include "wmma/wmma.hpp"
|
||||
#include "mfma/mfma.hpp"
|
||||
#include "sparse/sparse.hpp"
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
@@ -68,10 +69,12 @@ struct amdgcn_mma<fp16_t,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
{
|
||||
// Mfma operation type
|
||||
using OpType = MfmaOp;
|
||||
using OpType = MfmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
@@ -125,9 +128,11 @@ struct amdgcn_mma<fp16_t,
|
||||
32u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
{
|
||||
using OpType = MfmaOp;
|
||||
using OpType = MfmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Packed register types
|
||||
using AVecType = ext_vector_t<fp16_t, 8>;
|
||||
|
||||
@@ -52,7 +52,8 @@ struct MfmaDefaultSelector
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
|
||||
CompilerTarget>;
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
public:
|
||||
@@ -98,7 +99,8 @@ struct MfmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u,
|
||||
BlockN,
|
||||
1u,
|
||||
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
|
||||
CompilerTarget>;
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -114,6 +116,7 @@ struct MfmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u,
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -121,7 +124,8 @@ template <typename ADataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
@@ -129,7 +133,9 @@ struct MmaDefaultSelector<ADataType,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
OpFamily,
|
||||
enable_if_all<enable_if_target_family_gfx9_t<CompilerTarget>,
|
||||
std::enable_if_t<OpFamily == MmaOpFamily::DENSE>>>
|
||||
{
|
||||
private:
|
||||
// Provide the default depth-K search strategy for each class of common MFMA shapes.
|
||||
|
||||
@@ -57,6 +57,7 @@ template <typename ADataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
MmaOpFamily OpFamily,
|
||||
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
|
||||
typename CompilerTarget =
|
||||
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
|
||||
@@ -69,7 +70,8 @@ template <typename ADataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget>::SelectedOp,
|
||||
CompilerTarget,
|
||||
OpFamily>::SelectedOp,
|
||||
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
|
||||
typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
|
||||
struct WaveWiseMma
|
||||
|
||||
48
include/ck_tile/core/arch/mma/mma_op_family.hpp
Normal file
48
include/ck_tile/core/arch/mma/mma_op_family.hpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @enum MmaOpFamily
|
||||
* @brief Enumeration that defines mma op families and
|
||||
*/
|
||||
enum struct MmaOpFamily
|
||||
{
|
||||
UNDEFINED = 0,
|
||||
DENSE,
|
||||
SPARSE,
|
||||
SCALE,
|
||||
};
|
||||
|
||||
/**
|
||||
* @class is_ctrl_fis_mma_op_of_familylag_of_family
|
||||
* @brief Meta-function to check if MmaOp is of the specified MmaOpFamily
|
||||
* @tparam Family Control flag family
|
||||
* @tparam MmaOp amdgcn struct specialization type
|
||||
*/
|
||||
template <MmaOpFamily Family, typename MmaOp, typename = void>
|
||||
struct is_mma_op_of_family : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct is_mma_op_of_family
|
||||
* @brief Specialization for Family == MmaOp::OpFamily detection
|
||||
*/
|
||||
template <MmaOpFamily Family, typename MmaOp>
|
||||
struct is_mma_op_of_family<Family, MmaOp, std::enable_if_t<Family == MmaOp::OpFamily>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Convenience evaluator for is_mma_op_of_family trait
|
||||
* @tparam Family Desired control flag family
|
||||
* @tparam MmaOp The amdgcn struct specialization type to check
|
||||
*/
|
||||
template <MmaOpFamily Family, typename MmaOp>
|
||||
static constexpr bool is_mma_op_of_family_v = is_mma_op_of_family<Family, MmaOp>::value;
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
@@ -15,11 +16,12 @@ namespace ck_tile::core::arch::mma {
|
||||
* architecture.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Fragment M dimension
|
||||
* @tparam FragN Fragment N dimension
|
||||
* @tparam FragK Fragment K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
* @tparam Enable SFINAE enabler
|
||||
* @note Here we distinguish that Fragment MNK sizes from Block MNK sizes used in the actual MMA
|
||||
* operation. Fragment sizes correspond to the overall tile size being computed, while Block sizes
|
||||
@@ -34,14 +36,22 @@ template <typename ADataType,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily,
|
||||
typename Enable = void>
|
||||
// TODO c++20 requires
|
||||
struct MmaDefaultSelector
|
||||
{
|
||||
// By default, no selection is made, and we fall back to a pass-through unsupported
|
||||
// implementation. This is because we do not have any knowledge of the target architecture here.
|
||||
using SelectedOp =
|
||||
amdgcn_mma<ADataType, BDataType, CDataType, FragM, FragN, FragK, void, amdgcn_target<>>;
|
||||
using SelectedOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
void,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>;
|
||||
};
|
||||
|
||||
#if CK_TILE_CONCEPTS
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "mfma/mfma_traits.hpp"
|
||||
@@ -69,6 +71,7 @@ concept MmaOpParamsI = requires(MmaOpParams op) {
|
||||
{ MmaOpParams::BlockN } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOpParams::BlockK } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOpParams::GfxTargetId } -> std::convertible_to<amdgcn_target_arch_id>;
|
||||
{ MmaOpParams::Family } -> std::convertible_to<MmaOpFamily>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
@@ -92,7 +95,8 @@ template <typename ADataType_,
|
||||
uint32_t BlockN_,
|
||||
uint32_t BlockK_,
|
||||
typename CtrlFlags_,
|
||||
typename CompilerTarget_>
|
||||
typename CompilerTarget_,
|
||||
MmaOpFamily OpFamily_>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget_>
|
||||
struct MmaOpParams<amdgcn_mma<ADataType_,
|
||||
BDataType_,
|
||||
@@ -101,17 +105,19 @@ struct MmaOpParams<amdgcn_mma<ADataType_,
|
||||
BlockN_,
|
||||
BlockK_,
|
||||
CtrlFlags_,
|
||||
CompilerTarget_>>
|
||||
CompilerTarget_,
|
||||
OpFamily_>>
|
||||
{
|
||||
// Capture incoming template parameters
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using CDataType = CDataType_;
|
||||
static constexpr uint32_t BlockM = BlockM_;
|
||||
static constexpr uint32_t BlockN = BlockN_;
|
||||
static constexpr uint32_t BlockK = BlockK_;
|
||||
using CtrlFlags = CtrlFlags_;
|
||||
using CompilerTarget = CompilerTarget_;
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using CDataType = CDataType_;
|
||||
static constexpr uint32_t BlockM = BlockM_;
|
||||
static constexpr uint32_t BlockN = BlockN_;
|
||||
static constexpr uint32_t BlockK = BlockK_;
|
||||
using CtrlFlags = CtrlFlags_;
|
||||
using CompilerTarget = CompilerTarget_;
|
||||
static constexpr auto MmaOpFamily = OpFamily_;
|
||||
// TODO c++20static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_;
|
||||
};
|
||||
|
||||
@@ -131,6 +137,8 @@ struct MmaOpTraits : public MmaOpParams<MmaOp>
|
||||
using BVecType = typename MmaOp::BVecType;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
|
||||
static constexpr MmaOpFamily OpFamily = MmaOp::OpFamily;
|
||||
|
||||
// Capture layout parameters
|
||||
static constexpr index_t kAMBlock = MmaOp::kAMBlock;
|
||||
static constexpr index_t kBNBlock = MmaOp::kBNBlock;
|
||||
@@ -144,9 +152,13 @@ struct MmaOpTraits : public MmaOpParams<MmaOp>
|
||||
static constexpr index_t kCM1PerLane = MmaOp::kCM1PerLane;
|
||||
|
||||
// Additional traits to identify the type of MmaOp at compile time
|
||||
constexpr static bool IsMfma = is_mma_op_mfma_v<MmaOp>;
|
||||
constexpr static bool IsWmma = is_mma_op_wmma_v<MmaOp>;
|
||||
constexpr static bool IsSupported = is_mma_op_supported_v<MmaOp>;
|
||||
constexpr static bool IsMfma = is_mma_op_mfma_v<MmaOp>;
|
||||
constexpr static bool IsWmma = is_mma_op_wmma_v<MmaOp>;
|
||||
constexpr static bool IsDense = OpFamily == MmaOpFamily::DENSE;
|
||||
constexpr static bool IsSparse = OpFamily == MmaOpFamily::SPARSE;
|
||||
constexpr static bool IsScale = OpFamily == MmaOpFamily::SCALE;
|
||||
constexpr static bool IsSupported =
|
||||
is_mma_op_supported_v<MmaOp> && OpFamily != MmaOpFamily::UNDEFINED;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
151
include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp
Normal file
151
include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp
Normal file
@@ -0,0 +1,151 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class SparseMfmaDefaultSelector
|
||||
* @brief Implements a default sparse MFMA selector strategy. The SelectedOp can be unsupported.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Size of the M dimension
|
||||
* @tparam BlockN Size of the N dimension
|
||||
* @tparam BlockKTest Size of the K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockKTest,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) && is_power_of_two_integer(BlockKTest))
|
||||
struct SparseMfmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
// Define our candidate MFMA implementation for the current parameters
|
||||
using CandidateOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
DefaultSparseMfmaCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE>;
|
||||
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
|
||||
CandidateOp,
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
void,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the CDNA default MMA selector strategy for sparse MFMA.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Size of the M dimension of the fragment to decompose
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
OpFamily,
|
||||
enable_if_all<std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID,
|
||||
amdgcn_target_id::GFX942,
|
||||
amdgcn_target_id::GFX950)>,
|
||||
std::enable_if_t<OpFamily == MmaOpFamily::SPARSE>>>
|
||||
{
|
||||
private:
|
||||
// Provide the default depth-K search strategy for each class of common MFMA shapes.
|
||||
// Start searching from the largest K dimension MFMA shape down to the smallest.
|
||||
using CandidateOp16x16 = typename SparseMfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
using CandidateOp32x32 = typename SparseMfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
32u,
|
||||
32u,
|
||||
64u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Default operation triggers pass-through
|
||||
using DefaultOp = typename SparseMfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
1u,
|
||||
1u,
|
||||
1u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Traits for each candidate
|
||||
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
|
||||
using CandidateTraits32x32 = MmaOpTraits<CandidateOp32x32>;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the MFMA shape
|
||||
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
|
||||
(FragM % CandidateTraits16x16::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits16x16::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits16x16::BlockK == 0u);
|
||||
static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
|
||||
(FragM % CandidateTraits32x32::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits32x32::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits32x32::BlockK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported MFMA operation for the given fragment shape
|
||||
using SelectedOp =
|
||||
std::conditional_t<IsSupported32x32,
|
||||
CandidateOp32x32,
|
||||
std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
108
include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp
Normal file
108
include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp
Normal file
@@ -0,0 +1,108 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct DefaultSparseMfmaCtrlFlags
|
||||
* @brief Default MFMA sparse flags, select (VGPR[srcC][7..0]) if srcC is
|
||||
* 16-bit or (VGPR[srcC][15..0]) if srcC is 8-bit.
|
||||
*/
|
||||
struct DefaultSparseMfmaCtrlFlags
|
||||
{
|
||||
static constexpr SparseCompressionIndex CompressionIndex = SparseCompressionIndex::FIRST;
|
||||
};
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
|
||||
/**
|
||||
* @concept SparseMfmaCtrlFlags
|
||||
* @brief Expresses the interface of required members for each CtrlFlags type
|
||||
*/
|
||||
template <typename CtrlFlags>
|
||||
concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) {
|
||||
// Flag members for sparse MFMA instructions
|
||||
{ CtrlFlags::CompressionIndex } -> std::convertible_to<SparseCompressionIndex>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for Sparse MFMA (SMFMA) on GFX942, GFX950 targets
|
||||
*
|
||||
* This specialization implements the SMFMA instruction for fp16_t A and B
|
||||
* matrices with structured sparsity, fp32_t accumulator, with 16x16x32 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the Sparse MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsSparseMfmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<
|
||||
fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE,
|
||||
std::enable_if_t<is_any_value_of(
|
||||
CompilerTarget::TARGET_ID, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950)>>
|
||||
{
|
||||
using OpType = MfmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE;
|
||||
|
||||
static constexpr index_t ABVecN = 8;
|
||||
|
||||
using AVecType = ext_vector_t<fp16_t, ABVecN>;
|
||||
using BVecType = ext_vector_t<fp16_t, ABVecN>;
|
||||
using CVecType = ext_vector_t<fp32_t, 4>;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
static constexpr index_t kCompressionRatio = 2;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
|
||||
static_assert(CompressedSize == 4);
|
||||
// TODO: Compressing A on-the-fly should be OK for now, but we need to validate
|
||||
// and evaluate changing this to a transform at a higher level.
|
||||
// aVec not being const can cause problems when running multiple intrinsics.
|
||||
const int32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);
|
||||
|
||||
const AVecCompressed a_vec_pruned = {aVec[0], aVec[1], aVec[2], aVec[3]};
|
||||
|
||||
using namespace sparse::detail;
|
||||
static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
|
||||
return {__builtin_amdgcn_smfmac_f32_16x16x32_f16(
|
||||
a_vec_pruned, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
68
include/ck_tile/core/arch/mma/sparse/sparse.hpp
Normal file
68
include/ck_tile/core/arch/mma/sparse/sparse.hpp
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @enum SparseCompressionIndex
|
||||
* @brief Indicates which set of sparse-indices within a VGPR starting at srcC
|
||||
* containing 8-bits (for 16-bit source data) or 16-bits (for 8-bit source data)
|
||||
* of index information for a lane. \see DefaultSparseMfmaCtrlFlags
|
||||
*/
|
||||
enum struct SparseCompressionIndex : int
|
||||
{
|
||||
FIRST = 0, // Uses bits [7:0] or [15..0], for 16 and 8 bit data respectively
|
||||
SECOND = 1, // Uses bits [15:8] or [31:16], for 16 and 8 bit data respectively
|
||||
THIRD = 2, // Uses bits [23:16]
|
||||
FOURTH = 3, // Uses bits [31:24]
|
||||
};
|
||||
|
||||
namespace sparse::detail {
|
||||
|
||||
/**
|
||||
* @struct BuiltinParams
|
||||
* @brief Translates the SparseCompressionIndex to the correct CBSZ and ABID pairs for sparse
|
||||
* builtins. The actual behavior of the builtin depends on the input data type: 16-bit source data:
|
||||
* If CBSZ=0, ABID selects one of four 8-bit sets of sparse-indices within a VGPR starting at srcC
|
||||
* containing 8-bits of index information for a lane. If CBSZ!=0 the very first is selected
|
||||
* (VGPR[srcC][7..0]).
|
||||
*
|
||||
* 8-bit source data:
|
||||
* If CBSZ=0, ABID selects one of two 16-bit sets of sparse-indices within a VGPR starting at srcC
|
||||
* containing 16-bits of index information for a lane. If CBSZ!=0; the very first is selected
|
||||
* (VGPR[srcC][15..0]).
|
||||
*/
|
||||
struct BuiltinParams
|
||||
{
|
||||
int UseFirstIndex; // CBSZ
|
||||
int ByteIndexToOverride; // ABID
|
||||
};
|
||||
|
||||
template <SparseCompressionIndex Idx>
|
||||
static constexpr BuiltinParams getBuiltinParams()
|
||||
{
|
||||
BuiltinParams params;
|
||||
if constexpr(Idx == SparseCompressionIndex::FIRST)
|
||||
{
|
||||
params.UseFirstIndex = 1;
|
||||
params.ByteIndexToOverride = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
params.UseFirstIndex = 0;
|
||||
params.ByteIndexToOverride = static_cast<int>(Idx);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
} // namespace sparse::detail
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
// Include sparse MFMA traits and architecture-specific implementations
|
||||
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"
|
||||
7
include/ck_tile/core/arch/mma/sparse/sparse_selector.hpp
Normal file
7
include/ck_tile/core/arch/mma/sparse/sparse_selector.hpp
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp"
|
||||
48
include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp
Normal file
48
include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultTransformsSparse
|
||||
* @brief Implements the default transforms for Sparse
|
||||
*
|
||||
* For 2:4 structured sparsity with inline register metadata:
|
||||
* - ATransform: Pass-through (sparse operands formatted in Exec) TODO!
|
||||
* - BTransform: Pass-through (sparse operands already formatted)
|
||||
* - CTransform: Pass-through (input accumulator)
|
||||
* - DTransform: Pass-through (output accumulator as-is)
|
||||
*/
|
||||
struct MmaDefaultTransformsSparse
|
||||
{
|
||||
using ATransform = PassThroughTransform;
|
||||
using BTransform = PassThroughTransform;
|
||||
using CTransform = PassThroughTransform;
|
||||
using DTransform = PassThroughTransform;
|
||||
};
|
||||
|
||||
/**
|
||||
* @class MmaTransformsDefaultSelector
|
||||
* @brief Specialization for Sparse MFMA transforms
|
||||
* Provides default transform selection for sparse operations
|
||||
*
|
||||
* @tparam MmaOp Sparse MMA operation
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires(is_mma_op_sparse(MmaOp))
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::SPARSE>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsSparse;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
134
include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp
Normal file
134
include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class SparseWmmaDefaultSelector
|
||||
* @brief Implements a default sparse WMMA selector strategy. The SelectedOp can be unsupported.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Size of the M dimension
|
||||
* @tparam BlockN Size of the N dimension
|
||||
* @tparam BlockKTest Size of the K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockKTest,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_target_arch_rdna(CompilerTarget) && is_power_of_two_integer(BlockKTest))
|
||||
struct SparseWmmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
// Define our candidate WMMA implementation for the current parameters
|
||||
using CandidateOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
DefaultSparseWmmaCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::SPARSE>;
|
||||
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
|
||||
CandidateOp,
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
void,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>>;
|
||||
};
|
||||
|
||||
/**
 * @struct MmaDefaultSelector
 * @brief Default MMA selector strategy for sparse WMMA on the GFX12 target family
 *        (the specialization is enabled via enable_if_target_family_gfx12_t).
 * If no supported instruction is found, falls back to an unsupported pass-through implementation.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 * @tparam FragM Size of the M dimension of the fragment to decompose
 * @tparam FragN Size of the N dimension of the fragment to decompose
 * @tparam FragK Size of the K dimension of the fragment to decompose
 * @tparam CompilerTarget The compiler target
 * @tparam OpFamily The MMA operation family
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t FragM,
          uint32_t FragN,
          uint32_t FragK,
          typename CompilerTarget,
          MmaOpFamily OpFamily>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires
struct MmaDefaultSelector<ADataType,
                          BDataType,
                          CDataType,
                          FragM,
                          FragN,
                          FragK,
                          CompilerTarget,
                          OpFamily,
                          enable_if_all<enable_if_target_family_gfx12_t<CompilerTarget>,
                                        std::enable_if_t<OpFamily == MmaOpFamily::SPARSE>>>
{
    private:
    // Provide the default depth-K search strategy for each class of common WMMA shapes.
    // Start searching from the largest K dimension WMMA shape down to the smallest.
    // Currently only the 16x16x32 sparse WMMA shape is probed.
    using CandidateOp16x16 = typename SparseWmmaDefaultSelector<ADataType,
                                                                BDataType,
                                                                CDataType,
                                                                16u,
                                                                16u,
                                                                32u,
                                                                CompilerTarget>::SelectedOp;

    // Default operation triggers pass-through (1x1x1 has no backend implementation)
    using DefaultOp = typename SparseWmmaDefaultSelector<ADataType,
                                                         BDataType,
                                                         CDataType,
                                                         1u,
                                                         1u,
                                                         1u,
                                                         CompilerTarget>::SelectedOp;

    // Traits for each candidate
    using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;

    // Check if each candidate is supported for the given fragment sizes
    // For this case, we require the fragment sizes to be multiples of the WMMA shape
    static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
                                             (FragM % CandidateTraits16x16::BlockM == 0u) &&
                                             (FragN % CandidateTraits16x16::BlockN == 0u) &&
                                             (FragK % CandidateTraits16x16::BlockK == 0u);

    public:
    // Select the largest supported WMMA operation for the given fragment shape
    using SelectedOp = std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>;
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
73
include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp
Normal file
73
include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp
Normal file
@@ -0,0 +1,73 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/// @brief Default (empty) control-flag tag type used to select the sparse WMMA
/// amdgcn_mma specializations; sparse WMMA currently takes no control options.
struct DefaultSparseWmmaCtrlFlags
{
};
|
||||
|
||||
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
/// @brief Sparse (2:4 structured sparsity) WMMA 16x16x32 fp16 x fp16 -> fp32 for the
/// GFX12 target family, wrapping __builtin_amdgcn_swmmac_f32_16x16x32_f16_w32.
/// A is compressed on the fly inside exec(); the 2-bit position metadata produced by
/// compress_a_impl is passed as the intrinsic's index operand.
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t,
                  fp16_t,
                  fp32_t,
                  16u,
                  16u,
                  32u,
                  CtrlFlags,
                  CompilerTarget,
                  MmaOpFamily::SPARSE,
                  enable_if_target_family_gfx12_t<CompilerTarget>>
{
    using OpType = WmmaOp;
    static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE;

    // Number of A/B elements held per lane before compression
    static constexpr index_t ABVecN = 16;

    using AVecType = ext_vector_t<fp16_t, ABVecN>;
    using BVecType = ext_vector_t<fp16_t, ABVecN>;
    using CVecType = ext_vector_t<fp32_t, 8>;

    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;

    // Per-lane layout constants for the A/B operands
    // NOTE(review): these describe the wave-level tiling of the 16x16x32 shape;
    // confirm against the GFX12 SWMMAC operand layout documentation.
    static constexpr index_t kAMLane     = 16;
    static constexpr index_t kBNLane     = 16;
    static constexpr index_t kABKLane    = 4;
    static constexpr index_t kABKPerLane = 8;

    // Per-lane layout constants for the C accumulator
    static constexpr index_t kCMLane     = 4;
    static constexpr index_t kCNLane     = 16;
    static constexpr index_t kCM0PerLane = 1;
    static constexpr index_t kCM1PerLane = 4;

    // 2:4 structured sparsity halves the effective A operand size
    static constexpr index_t kCompressionRatio = 2;

    /// @brief Execute one sparse WMMA: out = A_compressed * B + C.
    /// @param aVec A operand; mutated in place by compress_a_impl (non-zeros moved to
    ///        the low half), then the low CompressedSize elements are fed to the intrinsic.
    /// @param bVec B operand (dense).
    /// @param cVec Accumulator input.
    /// @return Accumulator output vector.
    CK_TILE_DEVICE static auto
    exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
    {
        static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
        using AVecCompressed = ext_vector_t<fp16_t, CompressedSize>;
        static_assert(CompressedSize == 8);
        // TODO: Compressing A on-the-fly should be OK for now, but we need to validate
        // and evaluate changing this to a transform at a higher level.
        // aVec not being const can cause problems when running multiple intrinsics.
        const int32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);

        // After compression the valid data lives in the low 8 elements of aVec.
        const AVecCompressed a_vec_pruned = {
            aVec[0], aVec[1], aVec[2], aVec[3], aVec[4], aVec[5], aVec[6], aVec[7]};

        return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(a_vec_pruned, bVec, cVec, idx)};
    }
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
@@ -77,10 +78,12 @@ struct amdgcn_mma<fp16_t,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx11_t<CompilerTarget>>
|
||||
MmaOpFamily::DENSE,
|
||||
std::enable_if_t<is_target_family_gfx11<CompilerTarget>()>>
|
||||
{
|
||||
// Wmma operation type
|
||||
using OpType = WmmaOp;
|
||||
using OpType = WmmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types (duplicated input / b32 accum)
|
||||
using AVecType = ext_vector_t<fp16_t, 16>;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
@@ -37,10 +38,12 @@ struct amdgcn_mma<fp16_t,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
{
|
||||
// Wmma operation type
|
||||
using OpType = WmmaOp;
|
||||
using OpType = WmmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types
|
||||
using AVecType = ext_vector_t<fp16_t, 8>;
|
||||
|
||||
@@ -46,7 +46,8 @@ struct WmmaDefaultSelector
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
CtrlFlags,
|
||||
CompilerTarget>;
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
@@ -91,8 +92,15 @@ struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u,
|
||||
using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
|
||||
|
||||
// Default unsupported pass-through if no instruction is found
|
||||
using SelectedOp =
|
||||
amdgcn_mma<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CtrlFlags, CompilerTarget>;
|
||||
using SelectedOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
1u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -108,6 +116,7 @@ struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u,
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -115,7 +124,8 @@ template <typename ADataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget>
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
@@ -125,7 +135,9 @@ struct MmaDefaultSelector<ADataType,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
enable_if_target_arch_rdna_t<CompilerTarget>>
|
||||
OpFamily,
|
||||
enable_if_all<enable_if_target_arch_rdna_t<CompilerTarget>,
|
||||
std::enable_if_t<OpFamily == MmaOpFamily::DENSE>>>
|
||||
{
|
||||
private:
|
||||
// Provide the default depth-K search strategy for each class of common WMMA shapes.
|
||||
|
||||
@@ -4,8 +4,54 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
|
||||
* elements into lower part of a_vec to half its effective size.
|
||||
* @param a_vec Vector to be compressed.
|
||||
* @tparam ADataType The data type of a_vec
|
||||
* @tparam CompressedSize The target compression size
|
||||
* @tparam AVec The vector type of a_vec (deduced)
|
||||
* @return Packed 32‑bit word containing **CompressedSize** 2‑bit fields.
|
||||
* Each field encodes the original position (0–3) of the corresponding
|
||||
* non‑zero element in the input. If fewer than CompressedSize
|
||||
* non‑zeros are found, remaining fields default to 2 (see below).
|
||||
*/
|
||||
template <typename ADataType, index_t CompressedSize, typename AVec>
|
||||
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
|
||||
{
|
||||
// idx holds one 2‑bit index per output element (total CompressedSize entries).
|
||||
// It is initialized to the pattern 0b10 for every field. This matches
|
||||
// what the hardware expects when there are fewer than two non‑zero values
|
||||
// in a 4‑element group – the unused output is treated as coming from slot 2.
|
||||
// The loop below will clear and set each field as real non‑zeros are seen.
|
||||
int32_t idx = 0;
|
||||
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2 << (2 * k)); });
|
||||
|
||||
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
|
||||
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
|
||||
int32_t non_zero_pos = 0;
|
||||
|
||||
static_for<0, 3, 1>{}([&](auto j) {
|
||||
if(a_vec[i * 4 + j] != 0.0f)
|
||||
{
|
||||
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
|
||||
// clear the two‑bit field for this output and insert j
|
||||
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
|
||||
idx |= j << 2 * (i * 2 + non_zero_pos);
|
||||
++non_zero_pos;
|
||||
}
|
||||
});
|
||||
a_vec[i * 2] = nonzero_elems[0];
|
||||
a_vec[i * 2 + 1] = nonzero_elems[1];
|
||||
});
|
||||
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <typename WarpGemmAttribute_>
|
||||
struct WarpGemmSmfmacImpl
|
||||
{
|
||||
@@ -41,37 +87,10 @@ struct WarpGemmSmfmacImpl
|
||||
return WarpGemmAttribute_::get_num_of_access();
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------------------------
|
||||
/// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
|
||||
/// elements into lower part of a_vec to half its effective size.
|
||||
///
|
||||
/// @param a_vec Vector to be compressed.
|
||||
///
|
||||
/// @return Four 2-bit indexes of non-zero elements locations
|
||||
///
|
||||
template <typename AVec>
|
||||
CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const
|
||||
template <index_t CompressedSize, typename AVec>
|
||||
CK_TILE_DEVICE int32_t compress_a_vec(AVec& a_vec)
|
||||
{
|
||||
int32_t idx = 0b11101110;
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
|
||||
int32_t non_zero_pos = 0;
|
||||
|
||||
static_for<0, 3, 1>{}([&](auto j) {
|
||||
if(a_vec[i * 4 + j] != 0.0f)
|
||||
{
|
||||
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
|
||||
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
|
||||
idx |= j << 2 * (i * 2 + non_zero_pos);
|
||||
++non_zero_pos;
|
||||
}
|
||||
});
|
||||
a_vec[i * 2] = nonzero_elems[0];
|
||||
a_vec[i * 2 + 1] = nonzero_elems[1];
|
||||
});
|
||||
|
||||
return idx;
|
||||
return compress_a_impl<ADataType, CompressedSize>(a_vec);
|
||||
}
|
||||
|
||||
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
|
||||
@@ -84,10 +103,11 @@ struct WarpGemmSmfmacImpl
|
||||
constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
|
||||
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using AVecCompressed =
|
||||
ext_vector_t<ADataType, ATensor::get_thread_buffer_size() / CompressionRatio>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
static constexpr index_t CompressedSize =
|
||||
ATensor::get_thread_buffer_size() / CompressionRatio;
|
||||
using AVecCompressed = ext_vector_t<ADataType, CompressedSize>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
@@ -95,8 +115,9 @@ struct WarpGemmSmfmacImpl
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
|
||||
|
||||
const int32_t idx = compress_a(a_vec);
|
||||
const int32_t idx = compress_a_vec<CompressedSize>(a_vec);
|
||||
|
||||
static_assert(CompressedSize == 4);
|
||||
// @TODO can we simply set a_vec_pruned to a_vec[0:3]?
|
||||
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};
|
||||
|
||||
|
||||
@@ -7,6 +7,10 @@ if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx12")
|
||||
add_gtest_executable(test_amdgcn_sparse_mma test_amdgcn_sparse_mma.cpp)
|
||||
target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp)
|
||||
target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
34
test/ck_tile/core/arch/mma/get_wave_size_helper.hpp
Normal file
34
test/ck_tile/core/arch/mma/get_wave_size_helper.hpp
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
/// @brief Kernel that reports the wave size baked into the device compilation target.
/// Every launched thread writes the same value, so redundant stores are harmless.
/// @param waveSize Device pointer receiving CompilerTarget::WAVE_SIZE_ID (ignored if null).
__global__ void getWaveSizeForSelectedOp(uint32_t* waveSize)
{
    // Resolved at device-compile time to the target the kernel was built for.
    using CompilerTarget = decltype(ck_tile::core::arch::get_compiler_target());

    if(waveSize)
        *waveSize = static_cast<uint32_t>(CompilerTarget::WAVE_SIZE_ID);
}
|
||||
|
||||
/// @brief Query the wave size of the current device's compilation target by launching a
/// tiny kernel and copying the result back to the host.
/// @return The device's wave size (e.g. 32 or 64).
static __host__ uint32_t getDeviceWaveSize()
{
    uint32_t* d_wave_size = nullptr;
    HIP_CHECK_ERROR(hipMalloc(&d_wave_size, sizeof(uint32_t)));
    // 64 threads is enough for either wave size; all threads write the same value.
    getWaveSizeForSelectedOp<<<1, 64>>>(d_wave_size);
    HIP_CHECK_ERROR(hipDeviceSynchronize());
    uint32_t wave_size = 0;
    HIP_CHECK_ERROR(hipMemcpy(&wave_size, d_wave_size, sizeof(uint32_t), hipMemcpyDeviceToHost));
    // Release the device buffer — the original leaked this allocation on every call.
    HIP_CHECK_ERROR(hipFree(d_wave_size));
    return wave_size;
}
|
||||
|
||||
} // namespace
|
||||
@@ -11,6 +11,8 @@
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
|
||||
#include "get_wave_size_helper.hpp"
|
||||
|
||||
using namespace ck_tile;
|
||||
using namespace ck_tile::core::arch;
|
||||
using namespace ck_tile::core::arch::mma;
|
||||
@@ -47,10 +49,12 @@ struct amdgcn_mma<fp32_t,
|
||||
16u,
|
||||
DummyCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_id_dummy_t<CompilerTarget>>
|
||||
{
|
||||
// Mfma operation type
|
||||
using OpType = DummyOpType;
|
||||
using OpType = DummyOpType;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types
|
||||
using AVecType = ext_vector_t<fp32_t, 4>;
|
||||
@@ -81,8 +85,15 @@ struct amdgcn_mma<fp32_t,
|
||||
// Have an alias so we can test supported arch vs unsupported arch
|
||||
// TODO: c++20 template <amdgcn_target_arch_id CompilerTarget>
|
||||
template <typename CompilerTarget>
|
||||
using DummyAmdgcnMma =
|
||||
amdgcn_mma<fp32_t, fp32_t, fp32_t, 16u, 16u, 16u, DummyCtrlFlags, CompilerTarget>;
|
||||
using DummyAmdgcnMma = amdgcn_mma<fp32_t,
|
||||
fp32_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
DummyCtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
|
||||
/*! @struct MmaDefaultSelector
|
||||
* @brief For dummy Id only, instantiate tests for both MFMA and WMMA selectors so we can them both
|
||||
@@ -93,6 +104,7 @@ using DummyAmdgcnMma =
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -100,7 +112,8 @@ template <typename ADataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget>
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: requires
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
@@ -110,7 +123,9 @@ struct MmaDefaultSelector<ADataType,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
enable_if_target_id_dummy_t<CompilerTarget>>
|
||||
OpFamily,
|
||||
enable_if_all<enable_if_target_id_dummy_t<CompilerTarget>,
|
||||
std::enable_if_t<OpFamily == MmaOpFamily::DENSE>>>
|
||||
{
|
||||
using SelectedOp = DummyAmdgcnMma<CompilerTarget>;
|
||||
};
|
||||
@@ -128,6 +143,8 @@ TEST(TestAmdgcnMma, ArchSupported)
|
||||
// Check OpType
|
||||
EXPECT_TRUE(
|
||||
(std::is_same<typename MmaOp::OpType, DummyOpType>::value)); // OpType is DummyOpType
|
||||
// Check OpFamily
|
||||
EXPECT_TRUE((is_mma_op_of_family_v<MmaOpFamily::DENSE, MmaOp>));
|
||||
|
||||
// Check AVecType, BVecType, CVecType
|
||||
EXPECT_TRUE((std::is_same<typename MmaOp::AVecType, ext_vector_t<fp32_t, 4>>::value));
|
||||
@@ -157,6 +174,8 @@ TEST(TestAmdgcnMma, ArchUnsupported)
|
||||
|
||||
// OpType should be Unsupported
|
||||
EXPECT_TRUE((std::is_same<typename MmaOp::OpType, Unsupported>::value));
|
||||
// OpFamily should be Undefined
|
||||
EXPECT_TRUE((is_mma_op_of_family_v<MmaOpFamily::UNDEFINED, MmaOp>));
|
||||
|
||||
// AVecType, BVecType, CVecType should match default
|
||||
EXPECT_TRUE((std::is_same<typename MmaOp::AVecType, ext_vector_t<fp32_t, 1>>::value));
|
||||
@@ -367,6 +386,7 @@ TEST(TestAmdgcnMma, MmaOpTraitsUnsupportedMembers)
|
||||
EXPECT_TRUE((std::is_same<typename Traits::AVecType, ext_vector_t<fp32_t, 1>>::value));
|
||||
EXPECT_TRUE((std::is_same<typename Traits::BVecType, ext_vector_t<fp32_t, 1>>::value));
|
||||
EXPECT_TRUE((std::is_same<typename Traits::CVecType, ext_vector_t<fp32_t, 1>>::value));
|
||||
EXPECT_EQ(Traits::OpFamily, MmaOpFamily::UNDEFINED);
|
||||
EXPECT_EQ(Traits::kAMBlock, 0);
|
||||
EXPECT_EQ(Traits::kBNBlock, 0);
|
||||
EXPECT_EQ(Traits::kAMLane, 0);
|
||||
@@ -386,9 +406,14 @@ TEST(TestAmdgcnMma, MmaOpTraitsUnsupportedMembers)
|
||||
TEST(TestAmdgcnMma, MmaDefaultSelectorSupported)
|
||||
{
|
||||
// Direct selection of the supported dummy instruction
|
||||
using SelectedMma =
|
||||
typename MmaDefaultSelector<fp32_t, fp32_t, fp32_t, 16u, 16u, 16u, DummyCompilerTarget>::
|
||||
SelectedOp;
|
||||
using SelectedMma = typename MmaDefaultSelector<fp32_t,
|
||||
fp32_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
DummyCompilerTarget,
|
||||
MmaOpFamily::DENSE>::SelectedOp;
|
||||
// Should select DummyAmdgcnMma specialization
|
||||
EXPECT_TRUE((std::is_same<SelectedMma, DummyAmdgcnMma<DummyCompilerTarget>>::value));
|
||||
// OpType should be DummyOpType
|
||||
@@ -401,8 +426,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorSupported)
|
||||
TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupported)
|
||||
{
|
||||
// Direct selection of the unsupported dummy instruction
|
||||
using SelectedMma =
|
||||
MmaDefaultSelector<fp32_t, fp32_t, fp32_t, 16u, 16u, 16u, amdgcn_target<>>::SelectedOp;
|
||||
using SelectedMma = MmaDefaultSelector<fp32_t,
|
||||
fp32_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>::SelectedOp;
|
||||
// OpType should be Unsupported
|
||||
EXPECT_TRUE((std::is_same<typename SelectedMma::OpType, Unsupported>::value));
|
||||
// IsSupported should be false
|
||||
@@ -414,9 +445,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupported)
|
||||
TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment)
|
||||
{
|
||||
// Select indirectly with a fragment size of 256x128x64
|
||||
using SelectedMma =
|
||||
MmaDefaultSelector<fp32_t, fp32_t, fp32_t, 256u, 128u, 64u, DummyCompilerTarget>::
|
||||
SelectedOp;
|
||||
using SelectedMma = MmaDefaultSelector<fp32_t,
|
||||
fp32_t,
|
||||
fp32_t,
|
||||
256u,
|
||||
128u,
|
||||
64u,
|
||||
DummyCompilerTarget,
|
||||
MmaOpFamily::DENSE>::SelectedOp;
|
||||
// Should select DummyAmdgcnMma specialization
|
||||
EXPECT_TRUE((std::is_same<SelectedMma, DummyAmdgcnMma<DummyCompilerTarget>>::value));
|
||||
// OpType should be DummyOpType
|
||||
@@ -429,8 +465,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment)
|
||||
TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedFragment)
|
||||
{
|
||||
// This should fall back to unsupported since DummyAmdgcnMma only supports 16x16x16
|
||||
using SelectedMma =
|
||||
MmaDefaultSelector<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, DummyCompilerTarget>::SelectedOp;
|
||||
using SelectedMma = MmaDefaultSelector<fp32_t,
|
||||
fp32_t,
|
||||
fp32_t,
|
||||
8u,
|
||||
8u,
|
||||
8u,
|
||||
DummyCompilerTarget,
|
||||
MmaOpFamily::DENSE>::SelectedOp;
|
||||
EXPECT_FALSE((std::is_same<typename SelectedMma::OpType, Unsupported>::value));
|
||||
EXPECT_TRUE(MmaOpTraits<SelectedMma>::IsSupported);
|
||||
}
|
||||
@@ -438,8 +480,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedFragment)
|
||||
// Test MmaDefaultSelector for a different data type (fp16_t) and unsupported arch
|
||||
TEST(TestAmdgcnMma, MmaDefaultSelectorFp16Unsupported)
|
||||
{
|
||||
using SelectedMma =
|
||||
MmaDefaultSelector<fp16_t, fp16_t, fp16_t, 16u, 16u, 16u, amdgcn_target<>>::SelectedOp;
|
||||
using SelectedMma = MmaDefaultSelector<fp16_t,
|
||||
fp16_t,
|
||||
fp16_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>::SelectedOp;
|
||||
// Should select default amdgcn_mma (Unsupported)
|
||||
EXPECT_TRUE((std::is_same<typename SelectedMma::OpType, Unsupported>::value));
|
||||
EXPECT_FALSE(MmaOpTraits<SelectedMma>::IsSupported);
|
||||
@@ -464,7 +512,8 @@ __global__ void test_accum_over_k(void* a, void* b, void* c, void* out)
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
decltype(get_compiler_target())>;
|
||||
decltype(get_compiler_target()),
|
||||
MmaOpFamily::DENSE>;
|
||||
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
using MmaTraits = MmaOpTraits<MmaOp>;
|
||||
@@ -561,8 +610,9 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
|
||||
|
||||
// Need at least 1 WG with 64 threads to get defined MFMA/WMMA behaviour
|
||||
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK><<<1, 64>>>(d_a, d_b, d_c, d_out);
|
||||
const auto wave_size = getDeviceWaveSize();
|
||||
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK>
|
||||
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
|
||||
@@ -661,8 +711,9 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
|
||||
|
||||
// Need at least 1 WG with 64 threads to get defined MFMA/WMMA behaviour
|
||||
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK><<<1, 64>>>(d_a, d_b, d_c, d_out);
|
||||
const auto wave_size = getDeviceWaveSize();
|
||||
test_accum_over_k<AType, BType, CType, FragM, FragN, FragK>
|
||||
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
|
||||
|
||||
274
test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp
Normal file
274
test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp
Normal file
@@ -0,0 +1,274 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
#include "get_wave_size_helper.hpp"
|
||||
|
||||
using namespace ck_tile;
|
||||
using namespace ck_tile::core::arch;
|
||||
using namespace ck_tile::core::arch::mma;
|
||||
|
||||
using CompilerTargetGfx950 = decltype(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>());
|
||||
|
||||
// Compile-time check that the GFX950 sparse MFMA specialization (fp16 x fp16 -> fp32,
// 16x16x32) exists with the expected OpType tag and SPARSE family.
TEST(SparseMMATrait, SparseMfmaGfx950Specialization)
{
    // Test fp16 → fp32 sparse MFMA for GFX950 (16x16x32)
    using TestSparseMfma16x16 = amdgcn_mma<fp16_t,
                                           fp16_t,
                                           fp32_t,
                                           16u,
                                           16u,
                                           32u,
                                           DefaultSparseMfmaCtrlFlags,
                                           CompilerTargetGfx950,
                                           MmaOpFamily::SPARSE>;

    static_assert(std::is_same_v<typename TestSparseMfma16x16::OpType, MfmaOp> &&
                      TestSparseMfma16x16::OpFamily == MmaOpFamily::SPARSE,
                  "GFX950 sparse 16x16x32 should have SparseMFMAOp type");

    static_assert(is_mma_op_of_family_v<MmaOpFamily::SPARSE, TestSparseMfma16x16>,
                  "GFX950 sparse 16x16x32 should be detected as Sparse");

    std::cout << "GFX950 sparse MFMA specialization is correct" << std::endl;
}
|
||||
|
||||
// Compile-time check that MmaOpTraits classifies a sparse GFX950 MFMA op correctly
// (sparse, supported, MFMA, not WMMA).
TEST(SparseMMATrait, MmaOpTraitsIntegration)
{
    // Create a sparse MMA op (16x16x32 fp16 specialization)
    using TestSparseMmma = amdgcn_mma<fp16_t,
                                      fp16_t,
                                      fp32_t,
                                      16u,
                                      16u,
                                      32u,
                                      DefaultSparseMfmaCtrlFlags,
                                      CompilerTargetGfx950,
                                      MmaOpFamily::SPARSE>;

    // Get its traits
    using TestTraits = MmaOpTraits<TestSparseMmma>;

    // Verify trait detection
    static_assert(TestTraits::IsSparse, "Sparse MMA should be detected as sparse");
    static_assert(TestTraits::IsSupported, "Sparse MMA specialization should be supported");
    static_assert(TestTraits::IsMfma, "Sparse MFMA should be detected as MFMA");
    static_assert(!TestTraits::IsWmma, "Sparse MFMA should not be detected as WMMA");

    std::cout << "MmaOpTraits correctly integrates sparse operations" << std::endl;
}
|
||||
|
||||
// Compile-time check that dense and sparse MFMA specializations for the same shape share
// the same OpType tag but differ in OpFamily, and that MmaOpTraits distinguishes them.
TEST(SparseMMATrait, DenseVsSparseDistinction)
{
    // Dense MFMA from mfma/mfma_gfx9.hpp
    using DenseMfma = amdgcn_mma<fp16_t,
                                 fp16_t,
                                 fp32_t,
                                 16u,
                                 16u,
                                 32u,
                                 DefaultMfmaCtrlFlags,
                                 CompilerTargetGfx950,
                                 MmaOpFamily::DENSE>;

    // Sparse MFMA on GFX950
    using SparseMfma = amdgcn_mma<fp16_t,
                                  fp16_t,
                                  fp32_t,
                                  16u,
                                  16u,
                                  32u,
                                  DefaultSparseMfmaCtrlFlags,
                                  CompilerTargetGfx950,
                                  MmaOpFamily::SPARSE>;

    // Verify they have different operation types
    static_assert(std::is_same_v<typename DenseMfma::OpType, typename SparseMfma::OpType> &&
                      DenseMfma::OpFamily != SparseMfma::OpFamily,
                  "Dense and Sparse MFMA should have the same OpType tags and different OpFamily");

    // Verify traits correctly identify them
    static_assert(MmaOpTraits<DenseMfma>::IsMfma && MmaOpTraits<DenseMfma>::IsDense &&
                      !MmaOpTraits<DenseMfma>::IsSparse && !MmaOpTraits<DenseMfma>::IsScale &&
                      MmaOpTraits<DenseMfma>::IsSupported,
                  "Dense MFMA should be identified correctly");

    static_assert(MmaOpTraits<SparseMfma>::IsSparse && MmaOpTraits<SparseMfma>::IsMfma &&
                      !MmaOpTraits<SparseMfma>::IsDense && !MmaOpTraits<SparseMfma>::IsScale &&
                      MmaOpTraits<SparseMfma>::IsSupported,
                  "Sparse MFMA should be identified correctly");

    std::cout << "Dense and sparse MMA operations are correctly distinguished" << std::endl;
}
|
||||
|
||||
// Compile-time sweep over fragment sizes i in [1, 32] with K = 2*i: the sparse selector
// must pick a supported sparse MFMA only for the shapes it can tile (i == 16 or 32, i.e.
// fragments that are multiples of the 16x16x32 block), and the pass-through otherwise.
TEST(SparseMMATrait, SparseSelector)
{
    static_for<1, 33, 1>{}([](auto i) {
        using Selected = typename MmaDefaultSelector<fp16_t,
                                                     fp16_t,
                                                     fp32_t,
                                                     static_cast<uint32_t>(i),
                                                     static_cast<uint32_t>(i),
                                                     static_cast<uint32_t>(2 * i),
                                                     CompilerTargetGfx950,
                                                     MmaOpFamily::SPARSE>::SelectedOp;

        static constexpr bool isValid = (i == 16) || (i == 32);
        if constexpr(isValid)
        {
            // Selector should pick a sparse MFMA implementation
            static_assert(MmaOpTraits<Selected>::IsSparse);
            static_assert(MmaOpTraits<Selected>::IsMfma);
            static_assert(MmaOpTraits<Selected>::IsSupported);
            static_assert((std::is_same<typename Selected::OpType, MfmaOp>::value));
        }
        else
        {
            // Selector should pick the unsupported pass through
            static_assert(!MmaOpTraits<Selected>::IsSupported);
        }
    });
}
|
||||
|
||||
/// @brief Device kernel: select the sparse MMA op for the current compile target via
/// MmaDefaultSelector, then accumulate A*B into C over FragK/BlockK iterations.
/// @param a   Raw pointer to this lane's A operand, reinterpreted as MmaOp::AVecType.
/// @param b   Raw pointer to this lane's B operand, reinterpreted as MmaOp::BVecType.
/// @param c   Raw pointer to this lane's accumulator input (MmaOp::CVecType).
/// @param out Raw pointer receiving this lane's accumulator output (MmaOp::CVecType).
/// NOTE(review): the same A/B data is fed on every K iteration, so the host reference
/// presumably accumulates the identical product kIters times — confirm against the
/// host-side verification in the calling test.
template <typename AType,
          typename BType,
          typename CType,
          uint32_t FragM,
          uint32_t FragN,
          uint32_t FragK>
__global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out)
{
    using CompilerTarget = decltype(get_compiler_target());
    using Selector       = MmaDefaultSelector<AType,
                                        BType,
                                        CType,
                                        FragM,
                                        FragN,
                                        FragK,
                                        CompilerTarget,
                                        MmaOpFamily::SPARSE>;

    using MmaOp     = typename Selector::SelectedOp;
    using MmaTraits = MmaOpTraits<MmaOp>;

    using CVecType = typename MmaOp::CVecType;

    // Number of block-level MMA steps needed to cover the fragment's K extent.
    static constexpr uint32_t kIters = FragK / MmaTraits::BlockK;

    // Initialize the accumulator
    CVecType result = *reinterpret_cast<typename MmaOp::CVecType*>(c);

    // Accumulate input AxB over FragK/BlockK iterations
    for(uint32_t i = 0; i < kIters; ++i)
    {
        result = MmaOp::exec(*reinterpret_cast<typename MmaOp::AVecType*>(a),
                             *reinterpret_cast<typename MmaOp::BVecType*>(b),
                             result);
    }

    *reinterpret_cast<typename MmaOp::CVecType*>(out) = result;
}
|
||||
|
||||
// Live test on real hardware for sparse selection and execution.
|
||||
TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real)
|
||||
{
|
||||
int devCount;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceCount(&devCount));
|
||||
|
||||
hipDeviceProp_t devProp;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev));
|
||||
|
||||
auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName);
|
||||
bool hasDevice = static_cast<bool>(devCount > 0);
|
||||
int deviceWarpSize = devProp.warpSize;
|
||||
|
||||
bool isSupportedWmma = (currentArchId >= amdgcn_target_id::GFX1200) &&
|
||||
(currentArchId <= amdgcn_target_id::GFX12_GENERIC);
|
||||
bool isSupportedMfma =
|
||||
(currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950);
|
||||
// TODO: c++20 add check for arch id
|
||||
if(!hasDevice || (currentArchId == amdgcn_target_id::HOST) ||
|
||||
!(isSupportedWmma || isSupportedMfma))
|
||||
{
|
||||
GTEST_SKIP() << "No HIP device found. Skipping test.";
|
||||
}
|
||||
|
||||
using AType = fp16_t;
|
||||
using BType = fp16_t;
|
||||
using CType = fp32_t;
|
||||
|
||||
// Fragment size, also the expected block size from the selector.
|
||||
// Note: Actual blockK might be slightly different due to hardware implementation, but the
|
||||
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
|
||||
// correct.
|
||||
static constexpr uint32_t FragM = 16;
|
||||
static constexpr uint32_t FragN = 16;
|
||||
static constexpr uint32_t FragK = 32;
|
||||
static constexpr uint32_t BlockM = FragM;
|
||||
static constexpr uint32_t BlockN = FragN;
|
||||
static constexpr uint32_t BlockK = FragK;
|
||||
|
||||
// The number of elements per thread
|
||||
uint32_t AElements = BlockM * BlockK / deviceWarpSize;
|
||||
uint32_t BElements = BlockN * BlockK / deviceWarpSize;
|
||||
uint32_t CElements = BlockM * BlockN / deviceWarpSize;
|
||||
|
||||
uint32_t ASize = AElements * sizeof(AType);
|
||||
uint32_t BSize = BElements * sizeof(BType);
|
||||
uint32_t CSize = CElements * sizeof(CType);
|
||||
|
||||
// Initialize A and B to all 1's, C to all 0's
|
||||
std::vector<AType> h_a(AElements, static_cast<AType>(1));
|
||||
std::vector<BType> h_b(BElements, static_cast<BType>(1));
|
||||
std::vector<CType> h_c(CElements, static_cast<CType>(0));
|
||||
std::vector<CType> h_out(CElements, static_cast<CType>(0));
|
||||
|
||||
AType* d_a;
|
||||
BType* d_b;
|
||||
CType* d_c;
|
||||
CType* d_out;
|
||||
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_a, ASize));
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_b, BSize));
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_c, CSize));
|
||||
HIP_CHECK_ERROR(hipMalloc(&d_out, CSize));
|
||||
|
||||
// Copy inputs to device
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
|
||||
|
||||
const auto wave_size = getDeviceWaveSize();
|
||||
test_sparse_accum_over_k<AType, BType, CType, FragM, FragN, FragK>
|
||||
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
|
||||
|
||||
// Output should be FragK for all elements, because the inputs are all 1's
|
||||
for(size_t i = 0; i < CElements; ++i)
|
||||
{
|
||||
// In sparse only half of the A values are non-zero, thus the /2.
|
||||
CType expected = static_cast<CType>(FragK) / 2;
|
||||
|
||||
EXPECT_NEAR(h_out[i], expected, 1e-3);
|
||||
}
|
||||
|
||||
HIP_CHECK_ERROR(hipFree(d_a));
|
||||
HIP_CHECK_ERROR(hipFree(d_b));
|
||||
HIP_CHECK_ERROR(hipFree(d_c));
|
||||
HIP_CHECK_ERROR(hipFree(d_out));
|
||||
}
|
||||
Reference in New Issue
Block a user