From 10eb15416c161baec0591529441d9267b450ac8e Mon Sep 17 00:00:00 2001 From: Christopher Millette <63608002+cgmillette@users.noreply.github.com> Date: Mon, 24 Nov 2025 10:39:59 -0700 Subject: [PATCH] First look at mfma / wmma unification (#2704) * First look at mfma / wmma unification * Refactor * Re-org file structure * Restructure transform selection and WaveWiseMma class * Update license files. Add missing gfx1151 support. Change wave size for HOST to 1. Update datatypes naming consistency * Fixes default MmaSelector implementation * Adds unit tests for amdgcn_mma and arch * Consolidate common arch id checks to constexpr functions. Strongly type ids as amdgcn_target_arch_id object. * Refactor is_any_value_of * Fixes mma_selector logic * Fix typo * Add mma selector test for tile decomposition * Fix compilation of mma.hpp * Revert back to c++17 compatibility * Fix compiler error by returning index_t from get_warp_size() * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Fixes compiler error for missing is_wave32() function * Fixes compiler error for host wave_size() should be 64 * Fixes compiler errors where __cpp_concepts is not defined * Fixes compiler errors where __cpp_concepts is not defined * Fix test failure for host is wave64 by default --------- Co-authored-by: Chris Millette Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> [ROCm/composable_kernel commit: b9c6cb1452b25f06cea1f00d6c3d168334d65035] --- include/ck_tile/core.hpp | 16 + include/ck_tile/core/arch/arch.hpp | 762 +++++++++++++++++- include/ck_tile/core/arch/mma/amdgcn_mma.hpp | 118 +++ include/ck_tile/core/arch/mma/mfma/mfma.hpp | 10 + .../ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp | 162 ++++ .../core/arch/mma/mfma/mfma_selector.hpp | 189 +++++ .../core/arch/mma/mfma/mfma_traits.hpp | 44 + .../core/arch/mma/mfma/mfma_transforms.hpp | 38 + include/ck_tile/core/arch/mma/mma.hpp | 234 ++++++ 
.../ck_tile/core/arch/mma/mma_selector.hpp | 63 ++ include/ck_tile/core/arch/mma/mma_traits.hpp | 151 ++++ .../ck_tile/core/arch/mma/mma_transforms.hpp | 48 ++ include/ck_tile/core/arch/mma/wmma/wmma.hpp | 34 + .../ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp | 109 +++ .../ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp | 69 ++ .../core/arch/mma/wmma/wmma_selector.hpp | 161 ++++ .../core/arch/mma/wmma/wmma_traits.hpp | 44 + .../core/arch/mma/wmma/wmma_transforms.hpp | 112 +++ include/ck_tile/core/config.hpp | 239 ++++++ include/ck_tile/core/utility/ignore.hpp | 8 +- include/ck_tile/core/utility/type_traits.hpp | 21 +- test/ck_tile/CMakeLists.txt | 1 + test/ck_tile/core/CMakeLists.txt | 1 + test/ck_tile/core/arch/CMakeLists.txt | 13 + test/ck_tile/core/arch/mma/CMakeLists.txt | 12 + .../ck_tile/core/arch/mma/test_amdgcn_mma.cpp | 682 ++++++++++++++++ test/ck_tile/core/arch/test_arch.cpp | 396 +++++++++ 27 files changed, 3726 insertions(+), 11 deletions(-) create mode 100644 include/ck_tile/core/arch/mma/amdgcn_mma.hpp create mode 100644 include/ck_tile/core/arch/mma/mfma/mfma.hpp create mode 100644 include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp create mode 100644 include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp create mode 100644 include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp create mode 100644 include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp create mode 100644 include/ck_tile/core/arch/mma/mma.hpp create mode 100644 include/ck_tile/core/arch/mma/mma_selector.hpp create mode 100644 include/ck_tile/core/arch/mma/mma_traits.hpp create mode 100644 include/ck_tile/core/arch/mma/mma_transforms.hpp create mode 100644 include/ck_tile/core/arch/mma/wmma/wmma.hpp create mode 100644 include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp create mode 100644 include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp create mode 100644 include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp create mode 100644 include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp create mode 100644 
include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp create mode 100644 test/ck_tile/core/CMakeLists.txt create mode 100644 test/ck_tile/core/arch/CMakeLists.txt create mode 100644 test/ck_tile/core/arch/mma/CMakeLists.txt create mode 100644 test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp create mode 100644 test/ck_tile/core/arch/test_arch.cpp diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 57ef8705c4..5c05e9b6ee 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -13,6 +13,22 @@ #include "ck_tile/core/arch/amd_transpose_load_encoding.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/generic_memory_space_atomic.hpp" +#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mfma/mfma.hpp" +#include "ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp" +#include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp" +#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp" +#include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp" +#include "ck_tile/core/arch/mma/mma.hpp" +#include "ck_tile/core/arch/mma/mma_selector.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" +#include "ck_tile/core/arch/mma/mma_transforms.hpp" +#include "ck_tile/core/arch/mma/wmma/wmma.hpp" +#include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp" +#include "ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp" +#include "ck_tile/core/arch/mma/wmma/wmma_selector.hpp" +#include "ck_tile/core/arch/mma/wmma/wmma_traits.hpp" +#include "ck_tile/core/arch/mma/wmma/wmma_transforms.hpp" #include "ck_tile/core/arch/utility.hpp" #include "ck_tile/core/arch/workgroup_barrier.hpp" #include "ck_tile/core/config.hpp" diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index b66c00e392..70338e1185 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -1,5 +1,5 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once @@ -9,6 +9,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/utility/ignore.hpp" @@ -60,15 +61,753 @@ enum struct memory_operation_enum : std::uint16_t add }; -CK_TILE_HOST_DEVICE constexpr index_t get_warp_size() +namespace core::arch { + +/** + * @enum amdgcn_target_id + * @brief Defines constants for AMDGCN architecture target IDs + */ +enum struct amdgcn_target_id { -#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) - return 64; -#else - return 32; -#endif + GFX908 = 0x0908, // MI-100... + GFX90A = 0x090A, + GFX942 = 0x0942, + GFX950 = 0x0950, + GFX1030 = 0x1030, + GFX1031 = 0x1031, + GFX1032 = 0x1032, + GFX1034 = 0x1034, + GFX1035 = 0x1035, + GFX1036 = 0x1036, + GFX103_GENERIC = 0x103F, + GFX1100 = 0x1100, + GFX1101 = 0x1101, + GFX1102 = 0x1102, + GFX1103 = 0x1103, + GFX1150 = 0x1150, + GFX1151 = 0x1151, + GFX1152 = 0x1152, + GFX11_GENERIC = 0x11FF, + GFX1200 = 0x1200, + GFX1201 = 0x1201, + GFX12_GENERIC = 0x12FF, + HOST = 0x0000, +}; + +enum struct amdgcn_target_family_id +{ + GFX9 = 0x09, + GFX10_3 = 0x10, + GFX11 = 0x11, + GFX12 = 0x12, + HOST = 0x00, +}; + +enum struct amdgcn_target_arch_id +{ + CDNA = 0x01, + RDNA = 0x02, + HOST = 0x00, +}; + +enum struct amdgcn_target_wave_size_id +{ + WAVE32 = 32u, + WAVE64 = 64u, + HOST = 64u, // TODO: Is this correct? Should the host default to 64 or 1? 
+}; + +#if 1 //__cplusplus <= 201703L + +template +struct amdgcn_target +{ + static constexpr amdgcn_target_id TARGET_ID = TargetId; + static constexpr amdgcn_target_family_id FAMILY_ID = FamilyId; + static constexpr amdgcn_target_arch_id ARCH_ID = ArchId; + static constexpr amdgcn_target_wave_size_id WAVE_SIZE_ID = WaveSizeId; +}; + +template +static constexpr auto make_amdgcn_gfx9_target() +{ + return amdgcn_target{}; } +template +static constexpr auto make_amdgcn_gfx10_3_target() +{ + return amdgcn_target{}; +} + +template +static constexpr auto make_amdgcn_gfx11_target() +{ + return amdgcn_target{}; +} + +template +static constexpr auto make_amdgcn_gfx12_target() +{ + return amdgcn_target{}; +} + +template +static constexpr auto is_target_id_any_of() +{ + return is_any_value_of(CompilerTarget::TARGET_ID, TargetIds...); +} + +template +static constexpr auto is_target_family_any_of() +{ + return is_any_value_of(CompilerTarget::FAMILY_ID, FamilyIds...); +} + +template +static constexpr bool is_target_family_gfx9() +{ + return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX9; +} + +template +static constexpr bool is_target_family_gfx10_3() +{ + return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX10_3; +} + +template +static constexpr bool is_target_family_gfx11() +{ + return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX11; +} + +template +static constexpr bool is_target_family_gfx12() +{ + return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX12; +} + +template +static constexpr bool is_target_arch_cdna() +{ + return CompilerTarget::ARCH_ID == amdgcn_target_arch_id::CDNA; +} + +template +static constexpr bool is_target_arch_rdna() +{ + return CompilerTarget::ARCH_ID == amdgcn_target_arch_id::RDNA; +} + +template +static constexpr bool is_target_wave_size_32() +{ + return CompilerTarget::WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE32; +} + +template +static constexpr bool is_target_wave_size_64() +{ + return 
CompilerTarget::WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE64; +} + +// Helper to map compiler state to target arch id + +#define MAP_COMPILER_STATE_TO_GFX9_TARGET(COMPILER_STATE, TARGET_ID) \ + if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \ + { \ + return make_amdgcn_gfx9_target(); \ + } \ + else + +#define MAP_COMPILER_STATE_TO_GFX10_3_TARGET(COMPILER_STATE, TARGET_ID) \ + if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \ + { \ + return make_amdgcn_gfx10_3_target(); \ + } \ + else + +#define MAP_COMPILER_STATE_TO_GFX11_TARGET(COMPILER_STATE, TARGET_ID) \ + if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \ + { \ + return make_amdgcn_gfx11_target(); \ + } \ + else + +#define MAP_COMPILER_STATE_TO_GFX12_TARGET(COMPILER_STATE, TARGET_ID) \ + if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \ + { \ + return make_amdgcn_gfx12_target(); \ + } \ + else + +/** + * @brief Returns the amdgcn_target of the current compiler pass. + * @note This is where we tie the compiler state to our internal target architecture representation + * at compile time. 
+ */ +constexpr auto get_compiler_target() +{ + MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX908, GFX908); + MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX90A, GFX90A); + MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX942, GFX942); + MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX950, GFX950); + MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1030, GFX1030); + MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1031, GFX1031); + MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1032, GFX1032); + MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1034, GFX1034); + MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1035, GFX1035); + MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1036, GFX1036); + MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX10_3_GENERIC, GFX103_GENERIC); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1100, GFX1100); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1101, GFX1101); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1102, GFX1102); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1103, GFX1103); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1150, GFX1150); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1151, GFX1151); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1152, GFX1152); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX11_GENERIC, GFX11_GENERIC); + MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1200, GFX1200); + MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1201, GFX1201); + MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX12_GENERIC, GFX12_GENERIC); + + // Return HOST by default + if constexpr(amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE) + { + return amdgcn_target<>{}; + } +} + +// Cleanup +#undef MAP_COMPILER_STATE_TO_GFX9_TARGET +#undef MAP_COMPILER_STATE_TO_GFX10_3_TARGET +#undef MAP_COMPILER_STATE_TO_GFX11_TARGET +#undef MAP_COMPILER_STATE_TO_GFX12_TARGET + +// Sanity check: device compile must have a valid target 
architecture +static_assert(!amdgcn_compiler_target_state::CK_TILE_DEVICE_COMPILE || + get_compiler_target().TARGET_ID != amdgcn_target_id::HOST, + "Device compile must have a valid target device architecture"); + +// Sanity check: host compile must have HOST target architecture +static_assert(!amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE || + get_compiler_target().TARGET_ID == amdgcn_target_id::HOST, + "Host compile must target HOST architecture"); + +// TODO: c++20 use the make functions and constexpr if to avoid string construction and find at +// runtime +#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID(NAME_STRING, TARGET_ID) \ + if(str.find(NAME_STRING) != std::string::npos) \ + { \ + return amdgcn_target_id::TARGET_ID; \ + } \ + else + +/** + * @brief Converts a lower-case string to the corresponding amdgcn_target_arch_id value. + * Returns amdgcn_target_arch_id::HOST if no match is found. + * Matches if the input contains the architecture substring. + * Example: "gfx908", "gfx90a", "gfx1100", etc. can be parsed from hip runtime info. 
+ */ +// TODO: c++20 constexpr if and string_view to avoid std::string construction and find at runtime +// TODO: c++20 return amdgcn_target instance instead of just the target id +CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target_id(char const* testStr) +{ + auto str = std::string(testStr); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx908", GFX908); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx90a", GFX90A); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx942", GFX942); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx950", GFX950); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1030", GFX1030); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1031", GFX1031); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1032", GFX1032); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1034", GFX1034); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1035", GFX1035); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1036", GFX1036); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx10_3_generic", GFX103_GENERIC); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1100", GFX1100); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1101", GFX1101); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1102", GFX1102); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1103", GFX1103); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1150", GFX1150); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1151", GFX1151); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1152", GFX1152); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx11_generic", GFX11_GENERIC); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1200", GFX1200); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1201", GFX1201); + 
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx12_generic", GFX12_GENERIC); + + // Default case: return HOST target if no match is found + return amdgcn_target_id::HOST; +} + +#undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID + +/** + * @brief SFINAE enabler for a compiler target if the target id is in the list of supported target + * ids + * @tparam CompilerTarget The compiler target to check + * @tparam SupportedTargetIds The list of supported target ids, e.g., amdgcn_target_id::GFX908 + */ +template +using enable_if_target_id_t = + std::enable_if_t; + +/** + * @brief SFINAE enabler for a compiler target if the family id is in the list of supported family + * ids + * @tparam CompilerTarget The compiler target to check + * @tparam SupportedTargetFamilyIds The list of supported family ids, e.g., + * amdgcn_target_family_id::GFX9 + */ +template +using enable_if_target_family_id_t = + std::enable_if_t; + +/** + * @brief SFINAE enabler for a compiler target if the arch id is in the list of supported arch ids + * @tparam CompilerTarget The compiler target to check + * @tparam SupportedTargetArchIds The list of supported arch ids, e.g., amdgcn_target_arch_id::CDNA + */ +template +using enable_if_target_arch_id_t = + std::enable_if_t; + +/** + * @brief SFINAE enabler for a compiler target if the wave size id is in the list of supported wave + * size ids + * @tparam CompilerTarget The compiler target to check + * @tparam SupportedTargetWaveSizeIds The list of supported wave size ids, e.g., + * amdgcn_target_wave_size_id::WAVE64 + */ +template +using enable_if_target_wave_size_id_t = + std::enable_if_t; + +/// Specialized enablers for common families, architectures, and wave sizes /// + +/** + * @brief SFINAE enabler for GFX9 family targets + * @tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_family_gfx9_t = + enable_if_target_family_id_t; + +/** + * @brief SFINAE enabler for GFX10.3 family targets + * @tparam 
CompilerTarget The compiler target to check + */ +template +using enable_if_target_family_gfx10_3_t = + enable_if_target_family_id_t; + +/** + * @brief SFINAE enabler for GFX11 family targets + * @tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_family_gfx11_t = + enable_if_target_family_id_t; + +/** + * @brief SFINAE enabler for GFX12 family targets + * @tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_family_gfx12_t = + enable_if_target_family_id_t; + +/** + * @brief SFINAE enabler for CDNA architecture targets + * @tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_arch_cdna_t = + enable_if_target_arch_id_t; + +/** + * @brief SFINAE enabler for RDNA architecture targets + * @tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_arch_rdna_t = + enable_if_target_arch_id_t; + +/** + * @brief SFINAE enabler for WAVE32 size targets + * @tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_wave32_t = + enable_if_target_wave_size_id_t; + +/** + * @brief SFINAE enabler for WAVE64 size targets + * @tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_wave64_t = + enable_if_target_wave_size_id_t; + +#elif __cplusplus >= 202002L + +struct amdgcn_target +{ + // Target architecture identifiers + // These are set to HOST (0) by default + // TARGET_ID is the specific architecture id (e.g., GFX908) + // FAMILY_ID is the architecture family id (e.g., GFX9) + // ARCH_ID is the architecture class id (e.g., CDNA, RDNA) + // WAVE_SIZE_ID is the wavefront size id (e.g., WAVE32, WAVE64) + const amdgcn_target_id TARGET_ID = amdgcn_target_id::HOST; + const amdgcn_target_family_id FAMILY_ID = amdgcn_target_family_id::HOST; + const amdgcn_target_arch_id ARCH_ID = amdgcn_target_arch_id::HOST; + const amdgcn_target_wave_size_id WAVE_SIZE_ID = 
amdgcn_target_wave_size_id::HOST; +}; + +static constexpr auto make_amdgcn_gfx10_3_target(amdgcn_target_id targetId) +{ + return amdgcn_target{.TARGET_ID = targetId, + .FAMILY_ID = amdgcn_target_family_id::GFX10_3, + .ARCH_ID = amdgcn_target_arch_id::RDNA, + .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32}; +} + +static constexpr auto make_amdgcn_gfx9_target(amdgcn_target_id targetId) +{ + return amdgcn_target{.TARGET_ID = targetId, + .FAMILY_ID = amdgcn_target_family_id::GFX9, + .ARCH_ID = amdgcn_target_arch_id::CDNA, + .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE64}; +} + +static constexpr auto make_amdgcn_gfx11_target(amdgcn_target_id targetId) +{ + return amdgcn_target{.TARGET_ID = targetId, + .FAMILY_ID = amdgcn_target_family_id::GFX11, + .ARCH_ID = amdgcn_target_arch_id::RDNA, + .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32}; +} + +static constexpr auto make_amdgcn_gfx12_target(amdgcn_target_id targetId) +{ + return amdgcn_target{.TARGET_ID = targetId, + .FAMILY_ID = amdgcn_target_family_id::GFX12, + .ARCH_ID = amdgcn_target_arch_id::RDNA, + .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32}; +} + +static constexpr bool is_target_family_gfx9(amdgcn_target target) +{ + return target.FAMILY_ID == amdgcn_target_family_id::GFX9; +} + +static constexpr bool is_target_family_gfx10_3(amdgcn_target target) +{ + return target.FAMILY_ID == amdgcn_target_family_id::GFX10_3; +} + +static constexpr bool is_target_family_gfx11(amdgcn_target target) +{ + return target.FAMILY_ID == amdgcn_target_family_id::GFX11; +} + +static constexpr bool is_target_family_gfx12(amdgcn_target target) +{ + return target.FAMILY_ID == amdgcn_target_family_id::GFX12; +} + +static constexpr bool is_target_arch_cdna(amdgcn_target target) +{ + return target.ARCH_ID == amdgcn_target_arch_id::CDNA; +} + +static constexpr bool is_target_arch_rdna(amdgcn_target target) +{ + return target.ARCH_ID == amdgcn_target_arch_id::RDNA; +} + +static constexpr bool 
is_target_wave_size_32(amdgcn_target target) +{ + return target.WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE32; +} + +static constexpr bool is_target_wave_size_64(amdgcn_target target) +{ + return target.WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE64; +} + +// Helper to map compiler state to target arch id +#define MAP_COMPILER_STATE_TO_GFX10_3_TARGET(COMPILER_STATE, TARGET_ID) \ + if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \ + { \ + return make_amdgcn_gfx9_target(amdgcn_target_id::TARGET_ID); \ + } + +#define MAP_COMPILER_STATE_TO_GFX9_TARGET(COMPILER_STATE, TARGET_ID) \ + if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \ + { \ + return make_amdgcn_gfx9_target(amdgcn_target_id::TARGET_ID); \ + } + +#define MAP_COMPILER_STATE_TO_GFX11_TARGET(COMPILER_STATE, TARGET_ID) \ + if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \ + { \ + return make_amdgcn_gfx11_target(amdgcn_target_id::TARGET_ID); \ + } + +#define MAP_COMPILER_STATE_TO_GFX12_TARGET(COMPILER_STATE, TARGET_ID) \ + if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \ + { \ + return make_amdgcn_gfx12_target(amdgcn_target_id::TARGET_ID); \ + } + +/*! @brief Returns the amdgcn_target of the current compiler pass. + * @note This is where we tie the compiler state to our internal target architecture representation + * at compile time. 
+ */ +CK_TILE_HOST_DEVICE constexpr auto get_compiler_target() +{ + MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX908, GFX908); + MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX90A, GFX90A); + MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX942, GFX942); + MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX950, GFX950); + MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1030, GFX1030); + MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1031, GFX1031); + MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1032, GFX1032); + MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1034, GFX1034); + MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1035, GFX1035); + MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1036, GFX1036); + MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX10_3_GENERIC, GFX103_GENERIC); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1100, GFX1100); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1101, GFX1101); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1102, GFX1102); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1103, GFX1103); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1150, GFX1150); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1151, GFX1151); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1152, GFX1152); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX11_GENERIC, GFX11_GENERIC); + MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1200, GFX1200); + MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1201, GFX1201); + MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX12_GENERIC, GFX12_GENERIC); + + // Default to HOST + return amdgcn_target{}; +} + +// Cleanup +#undef MAP_COMPILER_STATE_TO_GFX9_TARGET +#undef MAP_COMPILER_STATE_TO_GFX10_3_TARGET +#undef MAP_COMPILER_STATE_TO_GFX11_TARGET +#undef MAP_COMPILER_STATE_TO_GFX12_TARGET + +// Sanity check: device compile must have a valid target architecture 
+static_assert(!amdgcn_compiler_target_state::CK_TILE_DEVICE_COMPILE || + get_compiler_target().TARGET_ID != amdgcn_target_id::HOST, + "Device compile must have a valid target device architecture"); + +// Sanity check: host compile must have HOST target architecture +static_assert(!amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE || + get_compiler_target().TARGET_ID == amdgcn_target_id::HOST, + "Host compile must target HOST architecture"); + +#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET(NAME_STRING, TARGET_ID) \ + if constexpr(str.find(NAME_STRING) != std::string::npos) \ + { \ + return make_amdgcn_gfx9_target(amdgcn_target_id::TARGET_ID); \ + } \ + else + +#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET(NAME_STRING, TARGET_ID) \ + if constexpr(str.find(NAME_STRING) != std::string::npos) \ + { \ + return make_amdgcn_gfx10_3_target(amdgcn_target_id::TARGET_ID); \ + } \ + else + +#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET(NAME_STRING, TARGET_ID) \ + if constexpr(str.find(NAME_STRING) != std::string::npos) \ + { \ + return make_amdgcn_gfx11_target(amdgcn_target_id::TARGET_ID); \ + } \ + else + +#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET(NAME_STRING, TARGET_ID) \ + if constexpr(str.find(NAME_STRING) != std::string::npos) \ + { \ + return make_amdgcn_gfx12_target(amdgcn_target_id::TARGET_ID); \ + } \ + else + +/** + * @brief Converts a lower-case string to the corresponding amdgcn_target_arch_id value. + * Returns amdgcn_target_arch_id::HOST if no match is found. + * Matches if the input contains the architecture substring. + * Example: "gfx908", "gfx90a", "gfx1100", etc. can be parsed from hip runtime info. 
+ */ +CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target(char const* testStr) +{ + auto str = std::string(testStr); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx908", GFX908); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx90a", GFX90A); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx942", GFX942); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx950", GFX950); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1030", GFX1030); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1031", GFX1031); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1032", GFX1032); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1034", GFX1034); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1035", GFX1035); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1036", GFX1036); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx10_3_generic", GFX103_GENERIC); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1100", GFX1100); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1101", GFX1101); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1102", GFX1102); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1103", GFX1103); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1150", GFX1150); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1151", GFX1151); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1152", GFX1152); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx11_generic", GFX11_GENERIC); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx1200", GFX1200); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx1201", GFX1201); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx12_generic", GFX12_GENERIC); + + // Default case + return amdgcn_target{}; +} + +#undef 
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET +#undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET +#undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET +#undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET + +/** + * @brief SFINAE enabler for a compiler target if the target id is in the list of supported target + * ids + * @tparam CompilerTarget The compiler target to check + * @tparam SupportedTargetIds The list of supported target ids, e.g., amdgcn_target_id::GFX908 + */ +template +using enable_if_target_id_t = + std::enable_if_t; + +/** + * @brief SFINAE enabler for a compiler target if the family id is in the list of supported family + * ids + * @tparam CompilerTarget The compiler target to check + * @tparam SupportedTargetFamilyIds The list of supported family ids, e.g., + * amdgcn_target_family_id::GFX9 + */ +template +using enable_if_target_family_id_t = + std::enable_if_t; + +/** + * @brief SFINAE enabler for a compiler target if the arch id is in the list of supported arch ids + * @tparam CompilerTarget The compiler target to check + * @tparam SupportedTargetArchIds The list of supported arch ids, e.g., amdgcn_target_arch_id::CDNA + */ +template +using enable_if_target_arch_id_t = + std::enable_if_t; + +/** + * @brief SFINAE enabler for a compiler target if the wave size id is in the list of supported wave + * size ids + * @tparam CompilerTarget The compiler target to check + * @tparam SupportedTargetWaveSizeIds The list of supported wave size ids, e.g., + * amdgcn_target_wave_size_id::WAVE64 + */ +template +using enable_if_target_wave_size_id_t = + std::enable_if_t; + +/// Specialized enablers for common families, architectures, and wave sizes /// + +/** + * @brief SFINAE enabler for GFX9 family targets + * @tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_family_gfx9_t = + enable_if_target_family_id_t; + +/** + * @brief SFINAE enabler for GFX10.3 family targets + * 
@tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_family_gfx10_3_t = + enable_if_target_family_id_t; + +/** + * @brief SFINAE enabler for GFX11 family targets + * @tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_family_gfx11_t = + enable_if_target_family_id_t; + +/** + * @brief SFINAE enabler for GFX12 family targets + * @tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_family_gfx12_t = + enable_if_target_family_id_t; + +/** + * @brief SFINAE enabler for CDNA architecture targets + * @tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_arch_cdna_t = + enable_if_target_arch_id_t; + +/** + * @brief SFINAE enabler for RDNA architecture targets + * @tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_arch_rdna_t = + enable_if_target_arch_id_t; + +/** + * @brief SFINAE enabler for WAVE32 size targets + * @tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_wave32_t = + enable_if_target_wave_size_id_t; + +/** + * @brief SFINAE enabler for WAVE64 size targets + * @tparam CompilerTarget The compiler target to check + */ +template +using enable_if_target_wave64_t = + enable_if_target_wave_size_id_t; + +#endif // __cplusplus <= 201703L + +} // namespace core::arch + CK_TILE_HOST bool is_wave32() { hipDeviceProp_t props{}; @@ -86,6 +825,13 @@ CK_TILE_HOST bool is_wave32() return props.major > 9; } +/*! 
@brief Returns the amdgcn_wave_size of the current compiler pass + */ +CK_TILE_HOST_DEVICE constexpr index_t get_warp_size() +{ + return static_cast(core::arch::get_compiler_target().WAVE_SIZE_ID); +} + CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; } CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; } diff --git a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp new file mode 100644 index 0000000000..88cf189667 --- /dev/null +++ b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -0,0 +1,118 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/core/utility/ignore.hpp" + +namespace ck_tile::core::arch::mma { + +/** + * @struct Unsupported + * @brief Meta-tag to indicate unsupported amdgcn_mma instance. + */ +struct Unsupported; + +#if defined(__cpp_concepts) && __cpp_concepts >= 201907L +/** + * @concept MmaOpI + * @brief Expresses the meta-data interface required for each MmaOp policy. 
+ */ +template +concept MmaOpI = requires(MmaOp op) { + // Requires an op context + typename MmaOp::OpType; + + // Captures types for inputs / outputs to mma function + typename MmaOp::AVecType; + typename MmaOp::BVecType; + typename MmaOp::CVecType; + + // Captures CK-specific layout properties + { MmaOp::kAMBlock } -> std::convertible_to; + { MmaOp::kBNBlock } -> std::convertible_to; + { MmaOp::kAMLane } -> std::convertible_to; + { MmaOp::kBNLane } -> std::convertible_to; + { MmaOp::kABKLane } -> std::convertible_to; + { MmaOp::kABKPerLane } -> std::convertible_to; + { MmaOp::kCMLane } -> std::convertible_to; + { MmaOp::kCNLane } -> std::convertible_to; + { MmaOp::kCM0PerLane } -> std::convertible_to; + { MmaOp::kCM1PerLane } -> std::convertible_to; + + // Static exec function + { + MmaOp::exec( + typename MmaOp::AVecType{}, typename MmaOp::BVecType{}, typename MmaOp::CVecType{}) + } -> std::convertible_to; +}; + +#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L + +/** + * @class amdgcn_mma + * @brief This is the default MmaOp policy. + * Instances of this class are to be used as MmaOp policies. + * Light builtin wrapper for mfma / wmma instructions. This class's job is to + * provide a uniform interface to invoke the appropriate instruction + * based on the template parameters provided. This interface is to bridge + * the gap between the ck_tile API types and the native __builtin types. + * @tparam ADataType Datatype of input A + * @tparam BDataType Datatype of input B + * @tparam CDataType Datatype of accumulator + * @tparam BlockM M-dimension of mma block + * @tparam BlockN N-dimension of mma block + * @tparam BlockK K-dimension of mma block + * @tparam CtrlFlags Control flags for mma operation + * @tparam CompilerTarget The current compiler target + * @tparam Enabler SFINAE enabler + */ +template +struct amdgcn_mma +{ + // The base instance is unsupported because there is no __builtin to wrap. 
+ using OpType = Unsupported; + + // Interface types for A, B, C vectors types + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + // Layout constants - default to 0 + static constexpr index_t kAMBlock = 0; + static constexpr index_t kBNBlock = 0; + + static constexpr index_t kAMLane = 0; + static constexpr index_t kBNLane = 0; + static constexpr index_t kABKLane = 0; + static constexpr index_t kABKPerLane = 0; + + static constexpr index_t kCMLane = 0; + static constexpr index_t kCNLane = 0; + static constexpr index_t kCM0PerLane = 0; + static constexpr index_t kCM1PerLane = 0; + + // This is a default pass-through implementation that doesn't do anything practical. + CK_TILE_DEVICE static CVecType const& + exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC) + { + ignore(regsA, regsB); + return regsC; // No-op, just return C + } +}; + +} // namespace ck_tile::core::arch::mma + +// Include the implementations +#include "wmma/wmma.hpp" +#include "mfma/mfma.hpp" diff --git a/include/ck_tile/core/arch/mma/mfma/mfma.hpp b/include/ck_tile/core/arch/mma/mfma/mfma.hpp new file mode 100644 index 0000000000..34c3b11d2f --- /dev/null +++ b/include/ck_tile/core/arch/mma/mfma/mfma.hpp @@ -0,0 +1,10 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +// Include the architecture-specific MFMA implementations and traits +#include "mfma_gfx9.hpp" +#include "mfma_traits.hpp" +#include "mfma_selector.hpp" +#include "mfma_transforms.hpp" diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp new file mode 100644 index 0000000000..94e429d385 --- /dev/null +++ b/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp @@ -0,0 +1,162 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "mfma_traits.hpp" + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" + +namespace ck_tile::core::arch::mma { + +// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed. +// This is because some built-ins are only available on certain target ids. +// We can also do things such as adding some padding specializations for when we need to use +// smaller values of K that aren't directly supported by the built-ins. +// For flexibility, it is recommended that for each backend wrapper it supports at least +// one packed register for each input to be able to process smaller K values by padding. + +/** + * @struct DefaultMfmaCtrlFlags + * @brief Default MFMA flags, no broadcasting or rotation of inputs + */ +struct DefaultMfmaCtrlFlags +{ + static constexpr uint32_t Cbsz = 0; // CBSZ flag, default 0 + static constexpr uint32_t Abid = 0; // ABID flag, default 0 + static constexpr uint32_t Blgp = 0; // BLGP flag, default 0 +}; + +#if defined(__cpp_concepts) && __cpp_concepts >= 201907L + +/** + * @concept CtrlFlagsGfx9I + * @brief Expresses the interface of required members for each CtrlFlags type on Gfx9 + */ +template +concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) { + // Flag members for Gfx9 MFMA instructions + { CtrlFlags::Cbsz } -> std::convertible_to; + { CtrlFlags::Abid } -> std::convertible_to; + { CtrlFlags::Blgp } -> std::convertible_to; +}; + +#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for MFMA on GFX9 targets + * + * This specialization implements the MFMA instruction for fp16_t A and B + * matrices, and fp32_t accumulator matrix, with 16x16x16 block sizes.
+ * + * @tparam CtrlFlags Control flags for the MFMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +struct amdgcn_mma> +{ + // Mfma operation type + using OpType = MfmaOp; + + // Register types + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + // Layout constants + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 4; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + CK_TILE_DEVICE static auto + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType + { + return {__builtin_amdgcn_mfma_f32_16x16x16f16(aVec, + bVec, + cVec, + static_cast(CtrlFlags::Cbsz), + static_cast(CtrlFlags::Abid), + static_cast(CtrlFlags::Blgp))}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for MFMA on GFX950 targets + * + * This specialization implements the MFMA instruction for fp16_t A and B + * matrices, and fp32_t accumulator matrix, with 16x16x32 block sizes. 
+ * + * @tparam CtrlFlags Control flags for the MFMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +struct amdgcn_mma> +{ + using OpType = MfmaOp; + + // Packed register types + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + // Layout constants + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 8; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + CK_TILE_DEVICE static auto + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType + { + return {__builtin_amdgcn_mfma_f32_16x16x32_f16(aVec, + bVec, + cVec, + static_cast(CtrlFlags::Cbsz), + static_cast(CtrlFlags::Abid), + static_cast(CtrlFlags::Blgp))}; + } +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp new file mode 100644 index 0000000000..b45da8a509 --- /dev/null +++ b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp @@ -0,0 +1,189 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" + +#include "mfma_traits.hpp" +#include "mfma_gfx9.hpp" + +namespace ck_tile::core::arch::mma { + +/** + * @class MfmaDefaultSelector + * @brief Implements a default MFMA selector strategy for gfx9 target architectures. 
+ * This implements the K dimension search strategy to find the largest supported MFMA + * instruction for the given M/N block sizes and datatypes. + * If no supported instruction is found, falls back to an unsupported pass-through + implementation. + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam BlockM Block M dimension size + * @tparam BlockN Block N dimension size + * @tparam BlockKTest Current Block K dimension size to test + * @tparam CompilerTarget The compiler target + * @note Here we assume that BlockKTest is always a power-of-two integer. + * The search strategy starts from a maximum BlockKTest size down to 1u by halving + * each time. + */ +template // TODO: c++20 amdgcn_target_arch_id CompilerTarget> +// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest)) +struct MfmaDefaultSelector +{ + private: + // Define our candidate MFMA implementation for the current parameters + using CandidateOp = + amdgcn_mma; + using CandidateTraits = MmaOpTraits; + + public: + // If the candidate is supported (e.g., a backend implementation exists), then select it. + // Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u + // and fall back to the unsupported pass-through implementation. + using SelectedOp = std::conditional_t::SelectedOp>; +}; + +/** + * @struct MfmaDefaultSelector + * @brief Implements the base case for the default MFMA selector when no supported instruction is + * found. 
+ * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam BlockM Block M dimension size + * @tparam BlockN Block N dimension size + * @tparam CompilerTarget The compiler target + */ +template // TODO: c++20 amdgcn_target_arch_id CompilerTarget> +struct MfmaDefaultSelector +{ + // Default unsupported pass-through if no instruction is found + using SelectedOp = + amdgcn_mma; +}; + +/** + * @struct MmaDefaultSelector + * @brief Implements the gfx9 default MMA selector strategy for wave-wise MMA decomposition. + * This implements the M/N block size search strategy to find the largest supported MFMA + * instruction for the given datatypes. + * If no supported instruction is found, falls back to an unsupported pass-through implementation. + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam FragM Size of the M dimension of the fragment to decompose + * @tparam FragN Size of the N dimension of the fragment to decompose + * @tparam FragK Size of the K dimension of the fragment to decompose + * @tparam CompilerTarget The compiler target + */ +template // TODO: c++20 amdgcn_target_arch_id CompilerTarget> +struct MmaDefaultSelector> +{ + private: + // Provide the default depth-K search strategy for each class of common MFMA shapes. + // Start searching from the largest K dimension MFMA shape down to the smallest. 
+ using CandidateOp4x4 = + typename MfmaDefaultSelector:: + SelectedOp; + using CandidateOp16x16 = typename MfmaDefaultSelector::SelectedOp; + using CandidateOp32x32 = typename MfmaDefaultSelector::SelectedOp; + + // Default operation triggers pass-through + using DefaultOp = + typename MfmaDefaultSelector:: + SelectedOp; + + // Traits for each candidate + using CandidateTraits4x4 = MmaOpTraits; + using CandidateTraits16x16 = MmaOpTraits; + using CandidateTraits32x32 = MmaOpTraits; + + // Check if each candidate is supported for the given fragment sizes + // For this case, we require the fragment sizes to be multiples of the MFMA shape + static constexpr bool IsSupported4x4 = + CandidateTraits4x4::IsSupported && (FragM % CandidateTraits4x4::BlockM == 0u) && + (FragN % CandidateTraits4x4::BlockN == 0u) && (FragK % CandidateTraits4x4::BlockK == 0u); + static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported && + (FragM % CandidateTraits16x16::BlockM == 0u) && + (FragN % CandidateTraits16x16::BlockN == 0u) && + (FragK % CandidateTraits16x16::BlockK == 0u); + static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported && + (FragM % CandidateTraits32x32::BlockM == 0u) && + (FragN % CandidateTraits32x32::BlockN == 0u) && + (FragK % CandidateTraits32x32::BlockK == 0u); + + public: + // Select the largest supported MFMA operation for the given fragment shape + using SelectedOp = std::conditional_t< + IsSupported32x32, + CandidateOp32x32, + std::conditional_t>>; +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp new file mode 100644 index 0000000000..b023118ab0 --- /dev/null +++ b/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp @@ -0,0 +1,44 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile::core::arch::mma { + +/** + * @struct MfmaOp + * @brief Meta-tag for the MFMA operation. This will be used in the MmaOp policies to + * identify the operation as an MFMA operation. + */ +struct MfmaOp; + +/** + * @class is_mma_op_mfma + * @brief Trait to check if MmaOp is an MFMA operation + * @tparam MmaOp The matrix multiply-accumulate operation type to check + */ +template +struct is_mma_op_mfma : std::false_type +{ +}; + +/** + * @struct is_mma_op_mfma + * @brief MmaOp specialization for MFMA operations, confirming the OpType matches MfmaOp + * @tparam MmaOp The matrix multiply-accumulate operation type to check + */ +template +// TODO: c++20 requires +struct is_mma_op_mfma>> + : std::true_type +{ +}; + +/** + * @brief Convenience evaluator for is_mma_op_mfma trait + * @tparam MmaOp The matrix multiply-accumulate operation type to check + */ +template +static constexpr bool is_mma_op_mfma_v = is_mma_op_mfma::value; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp new file mode 100644 index 0000000000..589e6e049c --- /dev/null +++ b/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp @@ -0,0 +1,38 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/arch/arch.hpp" + +namespace ck_tile::core::arch::mma { + +/** + * @struct MmaDefaultTransformsGfx9 + * @brief Implements the default MMA transforms for gfx9 targets + */ +struct MmaDefaultTransformsGfx9 +{ + using ATransform = PassThroughTransform; + using BTransform = PassThroughTransform; + using CTransform = PassThroughTransform; + using DTransform = PassThroughTransform; +}; + +/** + * @struct MmaTransformsDefaultSelector + * @brief Implements the default MMA transforms selection for gfx9 targets + * @tparam MmaOp Mma operation + * @tparam CompilerTarget The compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +struct MmaTransformsDefaultSelector> +{ + using SelectedTransforms = MmaDefaultTransformsGfx9; +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/mma.hpp b/include/ck_tile/core/arch/mma/mma.hpp new file mode 100644 index 0000000000..032261eb52 --- /dev/null +++ b/include/ck_tile/core/arch/mma/mma.hpp @@ -0,0 +1,234 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" + +#include "amdgcn_mma.hpp" +#include "mma_selector.hpp" +#include "mma_traits.hpp" +#include "mma_transforms.hpp" + +#include "mfma/mfma.hpp" +#include "wmma/wmma.hpp" + +namespace ck_tile::core::arch::mma { + +/*! @enum MmaAccumPolicy + * @brief Accumulation order for Mma decomposition + */ +enum struct MmaAccumPolicy +{ + // Decomposition and accumulation in row-major block order + ROW_MAJOR, + // Decomposition and accumulation in col-major block order + COL_MAJOR +}; + +/** + * @class Mma + * @brief Driver for the wave-tile Mma operation. 
Given a backend block-wise MmaOp implementation + * (e.g., mfma or wmma), this class performs block-wise decomposition to matrix-multiply input + * fragments of (A: FragM x FragK) x (B: FragK x FragN) and accumulates results into output fragment + * (C: FragM x FragN). + * @tparam ADataType Data type of input fragment A + * @tparam BDataType Data type of input fragment B + * @tparam CDataType Data type of input/output fragment C (accumulator) + * @tparam FragM Mma fragment M dimension + * @tparam FragN Mma fragment N dimension + * @tparam FragK Mma fragment K dimension + * @tparam AccumPolicy The block order of the accumulation registers (row major or col major block + * order) + * @tparam CompilerTarget The compiler target + * @tparam MmaOp The backend wrapper class that will perform block-wise mma op (e.g., mfma or + * wmma) + * @tparam MmaTransforms The set of transforms to be applied to input/output fragments + * @par This is an example of an Mma decomposition driver class that can be used in a wave-tile + * context. Given a fragment size, we can decompose the fragment into smaller block-wise mma ops + * that are natively supported by the hardware (e.g., mfma or wmma). The class also supports + * applying transforms to the input/output fragments as needed (e.g., layout conversions, data type + * conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the + * output fragment. This is a powerful example of how to build a flexible and reusable mma driver + * that can adapt to different hardware capabilities and requirements.
+ */ +template ::SelectedOp, + typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms = + typename MmaTransformsDefaultSelector::SelectedTransforms> +struct WaveWiseMma +{ + + using BlockWiseMmaOp = MmaOp; + using BlockWiseMmaOpTraits = MmaOpTraits; + + // Block dimensions + constexpr static uint32_t BlockM = BlockWiseMmaOpTraits::BlockM; + constexpr static uint32_t BlockN = BlockWiseMmaOpTraits::BlockN; + constexpr static uint32_t BlockK = BlockWiseMmaOpTraits::BlockK; + + // Block counts for decomposition + constexpr static uint32_t BlocksM = FragM / BlockM; + constexpr static uint32_t BlocksN = FragN / BlockN; + constexpr static uint32_t BlocksK = FragK / BlockK; + constexpr static uint32_t BlocksC = BlocksM * BlocksN; + + // Vector types for packed registers in each block + using AVecType = typename BlockWiseMmaOpTraits::AVecType; + using BVecType = typename BlockWiseMmaOpTraits::BVecType; + using CVecType = typename BlockWiseMmaOpTraits::CVecType; + + // Buffer types for fragments + using ABufferType = AVecType[BlocksM][BlocksK]; + using BBufferType = BVecType[BlocksN][BlocksK]; + using CBufferType = CVecType[BlocksM][BlocksN]; + + // Transforms + using ATransform = typename MmaTransforms::ATransform; + using BTransform = typename MmaTransforms::BTransform; + using CTransform = typename MmaTransforms::CTransform; + using DTransform = typename MmaTransforms::DTransform; + + // Sanity checks + static_assert(FragM >= BlockM, "FragM must be larger than BlockM"); + static_assert(FragN >= BlockN, "FragN must be larger than BlockN"); + static_assert(FragK >= BlockK, "FragK must be larger than BlockK"); + static_assert(FragM % BlockM == 0u, "FragM must be a multiple of BlockM"); + static_assert(FragN % BlockN == 0u, "FragN must be a multiple of BlockN"); + static_assert(FragK % BlockK == 0u, "FragK must be a multiple of BlockK"); + + private: + template + CK_TILE_DEVICE static auto formatBuffer(SrcT const& inputBuffer) + { + // TODO: Implement formatting 
logic as needed. + // This is intended to convert input fragments to the native vector types + // required by the BlockWiseMma operation for iteration + static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer"); + return reinterpret_cast(inputBuffer); + } + + template + CK_TILE_DEVICE static auto formatBuffer(SrcT& inputBuffer) + { + // TODO: Implement formatting logic as needed. + // This is intended to convert input fragments to the native vector types + // required by the BlockWiseMma operation for iteration + static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer"); + return reinterpret_cast(inputBuffer); + } + + /*! @brief Execute Mma in col-major accumulation order. + * @tparam VecTA The input fragment A vector type + * @tparam VecTB The input fragment B vector type + * @tparam VecTC The input/output fragment C vector type + */ + template + CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum) + { + // We implement an example wave-tile pipeline here. + // First, we apply the necessary transforms to the input fragments, + // then we convert the result into buffers of native vector formats + // that we can easily index. Native vector formats are necessary inputs + // to the given MmaOp exec function. + auto a_frag = formatBuffer(ATransform::exec(a)); + auto b_frag = formatBuffer(BTransform::exec(b)); + auto c_frag = formatBuffer(CTransform::exec(accum)); + + // "Col-major" accumulation over the M-dimension blocks first. + // Pseudo code here, but we would basically iterate over the blocks in col-major order + for(uint32_t bn = 0u; bn < BlocksN; ++bn) + { + for(uint32_t bm = 0u; bm < BlocksM; ++bm) + { + for(uint32_t bk = 0u; bk < BlocksK; ++bk) + { + c_frag[bm][bn] = + BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]); + } + } + } + + // Convert native vector results back to the output fragment format + // and then return after we apply the final output transform.
+ return DTransform::exec(formatBuffer>(c_frag)); + } + + /*! @brief Execute Mma in row-major accumulation order. + * @tparam VecTA The input fragment A vector type + * @tparam VecTB The input fragment B vector type + * @tparam VecTC The input/output fragment C vector type + */ + template + CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum) + { + // We implement an example wave-tile pipeline here. + // First, we apply the necessary transforms to the input fragments, + // then we convert the result into buffers of native vector formats + // that we can easily index. Native vector formats are necessary inputs + // to the given MmaOp exec function. + auto a_frag = formatBuffer(ATransform::exec(a)); + auto b_frag = formatBuffer(BTransform::exec(b)); + auto c_frag = formatBuffer(CTransform::exec(accum)); + + // "Row-major" accumulation over the N-dimension blocks first. + // Pseudo code here, but we would basically iterate over the blocks in row-major order. + // We also have to ensure that the incoming vector fragments are converted to native vector + // types before passing to the BlockWiseMma exec function. + for(uint32_t bm = 0u; bm < BlocksM; ++bm) + { + for(uint32_t bn = 0u; bn < BlocksN; ++bn) + { + for(uint32_t bk = 0u; bk < BlocksK; ++bk) + { + c_frag[bm][bn] = + BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]); + } + } + } + + // Convert native vector results back to the output fragment format + // and then return after we apply the final output transform. + return DTransform::exec(formatBuffer>(c_frag)); + } + + public: + /*! @brief Forward to Mma operation with specified accumulation order. 
+ * @tparam VecTA The input fragment A vector type + * @tparam VecTB The input fragment B vector type + * @tparam VecTC The input/output fragment C vector type + */ + template + CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum) + { + if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR) + { + return exec_row_major( + std::forward(a), std::forward(b), std::forward(accum)); + } + else // if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR) + { + return exec_col_major( + std::forward(a), std::forward(b), std::forward(accum)); + } + } +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/mma_selector.hpp b/include/ck_tile/core/arch/mma/mma_selector.hpp new file mode 100644 index 0000000000..b2845e9bb2 --- /dev/null +++ b/include/ck_tile/core/arch/mma/mma_selector.hpp @@ -0,0 +1,63 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once +namespace ck_tile::core::arch::mma { + +/** + * @class MmaDefaultSelector + * @brief Implements a default mma selector strategy for the current target architecture. + * This is simply intended as a default selection strategy for mma instruction operations. + * Given the particular datatypes and Fragment dimensions, the selector will attempt to + * select the instruction with the largest K dimension that is supported on the current target + * architecture. + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam FragM Fragment M dimension + * @tparam FragN Fragment N dimension + * @tparam FragK Fragment K dimension + * @tparam CompilerTarget The compiler target + * @tparam Enable SFINAE enabler + * @note Here we distinguish that Fragment MNK sizes from Block MNK sizes used in the actual MMA + * operation. 
Fragment sizes correspond to the overall tile size being computed, while Block sizes + * correspond to the size of the individual MMA instructions being used to compute the overall in + * block-wise. The Fragment sizes must be multiples of the Block sizes and in general larger than or + * equal to the Block sizes. + */ +template +// TODO c++20 requires +struct MmaDefaultSelector +{ + // By default, no selection is made, and we fall back to a pass-through unsupported + // implementation. This is because we do not have any knowledge of the target architecture here. + using SelectedOp = + amdgcn_mma>; +}; + +#if defined(__cpp_concepts) && __cpp_concepts >= 201907L + +/** + * @concept MmaSelectorI + * @brief Expresses the required members for each MmaSelector class. + */ +template +concept MmaSelectorI = requires(MmaSelector op) { + // Selectors should have a resulting SelectedOp type + typename MmaSelector::SelectedOp; +}; + +#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L + +} // namespace ck_tile::core::arch::mma + +// Include the implementations +#include "wmma/wmma_selector.hpp" +#include "mfma/mfma_selector.hpp" diff --git a/include/ck_tile/core/arch/mma/mma_traits.hpp b/include/ck_tile/core/arch/mma/mma_traits.hpp new file mode 100644 index 0000000000..29b7e106cb --- /dev/null +++ b/include/ck_tile/core/arch/mma/mma_traits.hpp @@ -0,0 +1,151 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#pragma once + +#include "amdgcn_mma.hpp" +#include "mfma/mfma_traits.hpp" +#include "wmma/wmma_traits.hpp" + +namespace ck_tile::core::arch::mma { + +/** + * @class is_mma_op_supported + * @brief Trait to check if MmaOp is supported + * @tparam MmaOp The matrix multiply-accumulate operation type to check + */ +// TODO: c++20 template +template +struct is_mma_op_supported : std::true_type +{ +}; + +/** + * @struct is_mma_op_supported + * @brief The MmaOp is unsupported specialization + * @tparam MmaOp The matrix multiply-accumulate operation type to check + */ +// TODO: c++20 template +template +// TODO: c++20 requires +struct is_mma_op_supported>> + : std::false_type +{ +}; + +/** + * @brief Convenience evaluation of is_mma_op_supported + * @tparam MmaOp The matrix multiply-accumulate operation type to check + */ +// TODO: c++20 template +template +static constexpr bool is_mma_op_supported_v = is_mma_op_supported::value; + +/** + * @class MmaOpParams + * @brief Reflects the template parameters of a given MmaOp + * @tparam MmaOp The matrix multiply-accumulate operation type to check + */ +// TODO: c++20 template +template +struct MmaOpParams; + +#if defined(__cpp_concepts) && __cpp_concepts >= 201907L + +/** + * @concept MmaOpParamsI + * @brief Expresses the required members for each MmaOp + */ +template +concept MmaOpParamsI = requires(MmaOpParams op) { + // Capture template parameters + typename MmaOpParams::ADataType; + typename MmaOpParams::BDataType; + typename MmaOpParams::CDataType; + typename MmaOpParams::CtrlFlags; + + { MmaOpParams::BlockM } -> std::convertible_to; + { MmaOpParams::BlockN } -> std::convertible_to; + { MmaOpParams::BlockK } -> std::convertible_to; + { MmaOpParams::GfxTargetId } -> std::convertible_to; +}; + +#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L + +/** + * @struct MmaOpParams + * @brief Reflects the template parameters of a given MmaOp + * @tparam ADataType_ Data type of matrix A + 
* @tparam BDataType_ Data type of matrix B + * @tparam CDataType_ Data type of the accumulator + * @tparam BlockM_ Size of the M dimension + * @tparam BlockN_ Size of the N dimension + * @tparam BlockK_ Size of the K dimension + * @tparam CtrlFlags_ Control flags for the MMA operation + * @tparam CompilerTarget_ The compiler target + */ +template +// TODO: c++20 amdgcn_target_arch_id CompilerTarget_> +struct MmaOpParams> +{ + // Capture incoming template parameters + using ADataType = ADataType_; + using BDataType = BDataType_; + using CDataType = CDataType_; + static constexpr uint32_t BlockM = BlockM_; + static constexpr uint32_t BlockN = BlockN_; + static constexpr uint32_t BlockK = BlockK_; + using CtrlFlags = CtrlFlags_; + using CompilerTarget = CompilerTarget_; + // TODO c++20static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_; +}; + +/** + * @class MmaOpTraits + * @brief Reflects the template parameters and static members of a given MmaOp. + * @tparam MmaOp The matrix multiply-accumulate operation + */ +template +// TODO: c++20 template +// TODO: c++20 requires MmaOpParamsI> +struct MmaOpTraits : public MmaOpParams +{ + // Capture internal MmaOp static members + using OpType = typename MmaOp::OpType; + using AVecType = typename MmaOp::AVecType; + using BVecType = typename MmaOp::BVecType; + using CVecType = typename MmaOp::CVecType; + + // Capture layout parameters + static constexpr index_t kAMBlock = MmaOp::kAMBlock; + static constexpr index_t kBNBlock = MmaOp::kBNBlock; + static constexpr index_t kAMLane = MmaOp::kAMLane; + static constexpr index_t kBNLane = MmaOp::kBNLane; + static constexpr index_t kABKLane = MmaOp::kABKLane; + static constexpr index_t kABKPerLane = MmaOp::kABKPerLane; + static constexpr index_t kCMLane = MmaOp::kCMLane; + static constexpr index_t kCNLane = MmaOp::kCNLane; + static constexpr index_t kCM0PerLane = MmaOp::kCM0PerLane; + static constexpr index_t kCM1PerLane = MmaOp::kCM1PerLane; + + // Additional traits to 
identify the type of MmaOp at compile time + constexpr static bool IsMfma = is_mma_op_mfma_v; + constexpr static bool IsWmma = is_mma_op_wmma_v; + constexpr static bool IsSupported = is_mma_op_supported_v; +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/mma_transforms.hpp b/include/ck_tile/core/arch/mma/mma_transforms.hpp new file mode 100644 index 0000000000..bbb0050084 --- /dev/null +++ b/include/ck_tile/core/arch/mma/mma_transforms.hpp @@ -0,0 +1,48 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +namespace ck_tile::core::arch::mma { + +/** + * @struct PassThroughTransform + * @brief A no-op transform that passes through the input as-is. + */ +struct PassThroughTransform +{ + template + CK_TILE_DEVICE static decltype(auto) exec(VecType&& v) + { + return std::forward(v); + } +}; + +/** + * @class MmaTransformsDefaultSelector + * @brief Default selector for MmaTransforms based on MmaOp and CompilerTarget + * @tparam MmaOp The Mma operation type + * @tparam CompilerTarget The compiler target + * @tparam Enable SFINAE parameter for specialization + */ +template +// TODO: c++20 template +struct MmaTransformsDefaultSelector; + +#if defined(__cpp_concepts) && __cpp_concepts >= 201907L + +/** + * @concept MmaTransformsI + * @brief Expresses the interface of required members for each MmaTransforms type. 
+ */ +template +concept MmaTransformsI = requires(MmaTransforms transforms) { + // Transforms should define ATransform, BTransform, CTransform, and DTransform types + typename MmaTransforms::ATransform; + typename MmaTransforms::BTransform; + typename MmaTransforms::CTransform; + typename MmaTransforms::DTransform; +}; + +#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/wmma/wmma.hpp b/include/ck_tile/core/arch/mma/wmma/wmma.hpp new file mode 100644 index 0000000000..8f79478b38 --- /dev/null +++ b/include/ck_tile/core/arch/mma/wmma/wmma.hpp @@ -0,0 +1,34 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile::core::arch::mma { + +/** + * @enum WmmaCtrlFlags + * @brief Common wmma control flags for gfx11 and gfx12 + */ +enum struct WmmaCtrlFlags : bool +{ + // Only has an effect on gfx11 when the accumulator is 16-bit + // Determines which half of the 32-bit accum register to use + // Low = bits [15:0] + // High = bits[31:16] + LOW = false, + HIGH = true, + + // Only has an effect on gfx11 / 12 when the input is 8-bit int + // Signage indicator of inputs / accum + UNSIGNED = false, + SIGNED = true +}; + +} // namespace ck_tile::core::arch::mma + +// Include the architecture-specific WMMA implementations and traits +#include "wmma_gfx11.hpp" +#include "wmma_gfx12.hpp" +#include "wmma_selector.hpp" +#include "wmma_traits.hpp" +#include "wmma_transforms.hpp" diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp new file mode 100644 index 0000000000..355fe6c957 --- /dev/null +++ b/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp @@ -0,0 +1,109 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "wmma_traits.hpp" + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" + +namespace ck_tile::core::arch::mma { +// TODO: Specifically for gfx11 wmma, we need to deal with quirks such as: +// - Duplicating A and B inputs +// - Handling C / D is always in b32, even for f16 accumulation. +// NOTE: Two suggestions: +// 1) We could do it here in the wrappers by accepting packed inputs, then swizzling them to +// duplicate the inputs as needed before calling the actual built-in. This may introduce +// some instruction overhead and violate single responsibility clauses, but keeps the logic +// contained within the backend wrapper. +// 2) We could do it at a higher level, e.g. in the Mma interface (workflow) by introducing +// pre-mma, mma and post-mma steps. The pre-mma step could handle input duplication transform; +// post-mma could implement D-shuffle transform. This may be cleaner and more flexible than +// trying to handle everything in the backend wrappers. +// +// This current example assumes duplication has already been done, and that C data shuffles have +// already been completed. (e.g. option 2 above). These expect duplicated inputs and pre-shuffled +// data in C. + +// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed. +// This is because some built-ins are only available on certain target ids. +// We can also do things, such as adding some padding specializations for when we need to use +// smaller values of K that aren't directly supported by the built-ins. +// For flexibility, it is recommended that for each backend wrapper it supports at least +// one packed register for each input to be able to process smaller K values by padding. + +/** + * @class DefaultWmmaCtrlFlags + * @brief Generates default WMMA control flags based on data types.
+ * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + */ +template +struct DefaultWmmaCtrlFlags +{ + // Generate default flags for signage + // Only used currently for integer inputs / accum in gfx11 / gfx12 + constexpr static WmmaCtrlFlags InputSignA = + std::is_signed_v ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED; + constexpr static WmmaCtrlFlags InputSignB = + std::is_signed_v ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED; + constexpr static WmmaCtrlFlags AccumSign = + std::is_signed_v ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED; + + // Generate default flags for accumulator destination bits. + // Only used if accumulation size is 16-bit in gfx11 + constexpr static WmmaCtrlFlags AccumBits = WmmaCtrlFlags::LOW; +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX11 + * architecture. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +struct amdgcn_mma> +{ + // Wmma operation type + using OpType = WmmaOp; + + // Register types (duplicated input / b32 accum) + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + // Layout constants + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 8; + static constexpr index_t kABKPerLane = 8; + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 2; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 1; + + CK_TILE_DEVICE static auto + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType + { + return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aVec, bVec, 
cVec)}; + } +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp new file mode 100644 index 0000000000..c41224b995 --- /dev/null +++ b/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp @@ -0,0 +1,69 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "wmma_traits.hpp" + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" + +namespace ck_tile::core::arch::mma { + +// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed. +// This is because some built-ins are only available on certain target ids. +// We can also do things, such as adding some padding specializations for when we need to use +// smaller values of K that aren't directly supported by the built-ins. +// For flexibility, it is recommended that for each backend wrapper it supports at least +// one packed register for each input to be able to process smaller K values by padding. + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX12 + * architecture.
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +struct amdgcn_mma> +{ + // Wmma operation type + using OpType = WmmaOp; + + // Register types + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + // Layout constants + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 8; + static constexpr index_t kABKPerLane = 8; + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 2; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 1; + + CK_TILE_DEVICE static auto + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType + { + return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(aVec, bVec, cVec)}; + } +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp new file mode 100644 index 0000000000..401d672126 --- /dev/null +++ b/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp @@ -0,0 +1,161 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma_selector.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" + +namespace ck_tile::core::arch::mma { + +/** + * @class WmmaDefaultSelector + * @brief Implements a default WMMA selector strategy for gfx11/12 target architectures. + * This implements the K dimension search strategy to find the largest supported WMMA + * instruction for the given M/N block sizes and datatypes. 
+ * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam BlockM Size of the M dimension + * @tparam BlockN Size of the N dimension + * @tparam BlockKTest Size of the K dimension + * @tparam CompilerTarget The compiler target + */ +template +// TODO: c++20 amdgcn_target_arch_id CompilerTarget> +// TODO: c++20 requires(is_rdna_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest)) +struct WmmaDefaultSelector +{ + private: + // By default, let's assume no special flags for WMMA + using CtrlFlags = DefaultWmmaCtrlFlags; + + // Define our candidate WMMA implementation for the current parameters + using CandidateOp = amdgcn_mma; + + using CandidateTraits = MmaOpTraits; + + public: + // If the candidate is supported (e.g., a backend implementation exists), then select it. + // Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u + // and fall back to the unsupported pass-through implementation. + using SelectedOp = std::conditional_t::SelectedOp>; +}; + +/** + * @struct WmmaDefaultSelector + * @brief Implements a default WMMA selector strategy for gfx11/12 target architectures. + * This implements the K dimension == 1, which is the base case for the recursive K dimension + * search. If no supported instruction is found, falls back to an unsupported pass-through + * implementation. 
+ * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam BlockM Size of the M dimension + * @tparam BlockN Size of the N dimension + * @tparam CompilerTarget The compiler target + */ +template +// TODO: c++20 amdgcn_target_arch_id GfxTargetId> +struct WmmaDefaultSelector +{ + // By default, let's assume no special flags for WMMA + using CtrlFlags = DefaultWmmaCtrlFlags; + + // Default unsupported pass-through if no instruction is found + using SelectedOp = + amdgcn_mma; +}; + +/** + * @struct MmaDefaultSelector + * @brief Implements the rdna default MMA selector strategy for wave-wise MMA decomposition. + * This implements the M/N block size search strategy to find the largest supported WMMA + * instruction for the given datatypes. + * If no supported instruction is found, falls back to an unsupported pass-through implementation. + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam FragM Size of the M dimension of the fragment to decompose + * @tparam FragN Size of the N dimension of the fragment to decompose + * @tparam FragK Size of the K dimension of the fragment to decompose + * @tparam CompilerTarget The compiler target + */ +template +// TODO: c++20 amdgcn_target_arch_id CompilerTarget> +// TODO: c++20 requires +struct MmaDefaultSelector> +{ + private: + // Provide the default depth-K search strategy for each class of common WMMA shapes. + // Start searching from the largest K dimension MFMA shape down to the smallest. 
+ using CandidateOp16x16 = typename WmmaDefaultSelector::SelectedOp; + + // Default operation triggers pass-through + using DefaultOp = + typename WmmaDefaultSelector:: + SelectedOp; + + // Traits for each candidate + using CandidateTraits16x16 = MmaOpTraits; + + // Check if each candidate is supported for the given fragment sizes + // For this case, we require the fragment sizes to be multiples of the WMMA shape + static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported && + (FragM % CandidateTraits16x16::BlockM == 0u) && + (FragN % CandidateTraits16x16::BlockN == 0u) && + (FragK % CandidateTraits16x16::BlockK == 0u); + + public: + // Select the largest supported WMMA operation for the given fragment shape + using SelectedOp = std::conditional_t; +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp new file mode 100644 index 0000000000..9e2e42a9d7 --- /dev/null +++ b/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp @@ -0,0 +1,44 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile::core::arch::mma { + +/** + * @struct WmmaOp + * @brief Meta-tag for the WMMA operation. This will be used in the MmaOp struct to + * identify the operation as an WMMA operation. 
+ */ +struct WmmaOp; + +/** + * @class is_mma_op_wmma + * @brief Trait to check if MmaOp is an WMMA operation + * @tparam MmaOp The matrix multiply-accumulate operation type to check + */ +template +struct is_mma_op_wmma : std::false_type +{ +}; + +/** + * @struct is_mma_op_wmma + * @brief MmaOp specialization for WMMA operations, confirming the OpType matches WmmaOp + * @tparam MmaOp The matrix multiply-accumulate operation type to check + */ +template +// TODO: c++20 requires +struct is_mma_op_wmma>> + : std::true_type +{ +}; + +/** + * @brief Convenience evaluator for is_mma_op_wmma trait + * @tparam MmaOp The matrix multiply-accumulate operation type to check + */ +template +static constexpr bool is_mma_op_wmma_v = is_mma_op_wmma::value; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp new file mode 100644 index 0000000000..2877e8f1f8 --- /dev/null +++ b/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp @@ -0,0 +1,112 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/mma_transforms.hpp" + +namespace ck_tile::core::arch::mma { + +/** + * @struct DuplicateTransform + * @brief Transform to duplicate low register elements to high register elements + */ +struct DuplicateTransform +{ + template + CK_TILE_DEVICE static decltype(auto) exec(VecType&& v) + { + // TODO: Implement duplication logic to broadcast low + // register elements to high elements [0 - (N/2 -1)] -> [N/2 - (N-1)] + return std::forward(v); + } +}; + +/** + * @struct PadTransform + * @brief Transform to pad data from original type to b32 type + */ +struct PadTransform +{ + template + CK_TILE_DEVICE static decltype(auto) exec(VecType&& v) + { + // TODO: Implement b32 padding logic. 
+ // E.g., for fp16, pad each 16-bit element with 16 bits of 0 to make 32-bit elements + return std::forward(v); + } +}; + +/** + * @struct UnpadTransform + * @brief Transform to unpad data from b32 type to original type + */ +struct UnpadTransform +{ + template + CK_TILE_DEVICE static decltype(auto) exec(VecType&& v) + { + // TODO: Implement b32 logic to unpad 32 to original data type. + return std::forward(v); + } +}; + +/** + * @struct MmaDefaultTransformsGfx11 + * @brief Default MMA transforms for GFX11 architecture + */ +struct MmaDefaultTransformsGfx11 +{ + using ATransform = DuplicateTransform; + using BTransform = DuplicateTransform; + using CTransform = PadTransform; + using DTransform = UnpadTransform; +}; + +/** + * @struct MmaDefaultTransformsGfx12 + * @brief Default MMA transforms for GFX12 architecture + */ +struct MmaDefaultTransformsGfx12 +{ + using ATransform = PassThroughTransform; + using BTransform = PassThroughTransform; + using CTransform = PassThroughTransform; + using DTransform = PassThroughTransform; +}; + +/** + * @struct MmaTransformsDefaultSelector + * @brief Implements the default MMA transforms selection for gfx11 targets + * @tparam MmaOp Mma operation + * @tparam CompilerTarget The compiler target + */ +template +// TODO: c++20 template +// TODO: c++20 requires +struct MmaTransformsDefaultSelector> +{ + using SelectedTransforms = MmaDefaultTransformsGfx11; +}; + +/** + * @struct MmaTransformsDefaultSelector + * @brief Implements the default MMA transforms selection for gfx12 targets + * @tparam MmaOp Mma operation + * @tparam CompilerTarget The compiler target + */ +template +// TODO: c++20 template +// TODO: c++20 requires +struct MmaTransformsDefaultSelector> +{ + using SelectedTransforms = MmaDefaultTransformsGfx12; +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index b01f9dedef..91e6134ac8 100644 --- a/include/ck_tile/core/config.hpp +++ 
b/include/ck_tile/core/config.hpp @@ -284,3 +284,242 @@ #ifndef CK_TILE_ENC_SUPPORT_Y_TO_R #define CK_TILE_ENC_SUPPORT_Y_TO_R 0 #endif + +// Mark unsupported features with a deprecation warning in debug builds +#if defined(NDEBUG) +#define CK_TILE_UNSUPPORTED_IMPL(MSG) +#else +#define CK_TILE_UNSUPPORTED_IMPL(MSG) __attribute__((deprecated(MSG))) +#endif + +namespace ck_tile::core { +/** + * @struct amdgcn_compiler_target_state + * @brief Defines compiler states for supported AMDGCN devices. + * @var CK_TILE_HOST_COMPILE Indicates if the compilation is for the host. + * @var CK_TILE_DEVICE_COMPILE Indicates if the compilation is for AMDGCN device. + * @var CK_TILE_ARCH_GFX908 Indicates if the compiler target architecture is GFX908. + * @var CK_TILE_ARCH_GFX90A Indicates if the compiler target architecture is GFX90A. + * @var CK_TILE_ARCH_GFX942 Indicates if the compiler target architecture is GFX942. + * @var CK_TILE_ARCH_GFX950 Indicates if the compiler target architecture is GFX950. + * @var CK_TILE_ARCH_GFX1030 Indicates if the compiler target architecture is GFX1030. + * @var CK_TILE_ARCH_GFX1031 Indicates if the compiler target architecture is GFX1031. + * @var CK_TILE_ARCH_GFX1032 Indicates if the compiler target architecture is GFX1032. + * @var CK_TILE_ARCH_GFX1034 Indicates if the compiler target architecture is GFX1034. + * @var CK_TILE_ARCH_GFX1035 Indicates if the compiler target architecture is GFX1035. + * @var CK_TILE_ARCH_GFX1036 Indicates if the compiler target architecture is GFX1036. + * @var CK_TILE_ARCH_GFX10_3_GENERIC Indicates if the compiler target architecture is GFX10.3 + * generic. + * @var CK_TILE_ARCH_GFX1100 Indicates if the compiler target architecture is GFX1100. + * @var CK_TILE_ARCH_GFX1101 Indicates if the compiler target architecture is GFX1101. + * @var CK_TILE_ARCH_GFX1102 Indicates if the compiler target architecture is GFX1102. + * @var CK_TILE_ARCH_GFX1151 Indicates if the compiler target architecture is GFX1151. 
+ * @var CK_TILE_ARCH_GFX1152 Indicates if the compiler target architecture is GFX1152. + * @var CK_TILE_ARCH_GFX11_GENERIC Indicates if the compiler target architecture is GFX11 generic. + * @var CK_TILE_ARCH_GFX1200 Indicates if the compiler target architecture is GFX1200. + * @var CK_TILE_ARCH_GFX1201 Indicates if the compiler target architecture is GFX1201. + * @var CK_TILE_ARCH_GFX12_GENERIC Indicates if the compiler target architecture is GFX12 generic. + */ +struct amdgcn_compiler_target_state +{ + // Determine if we are compiling for device or host +#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ + static constexpr bool CK_TILE_DEVICE_COMPILE = true; + static constexpr bool CK_TILE_HOST_COMPILE = false; +#else + static constexpr bool CK_TILE_DEVICE_COMPILE = false; + static constexpr bool CK_TILE_HOST_COMPILE = true; +#endif // __HIP_DEVICE_COMPILE__ && __HIP_DEVICE_COMPILE__ + + // GFX9 +#if defined(__gfx908__) + static constexpr bool CK_TILE_ARCH_GFX908 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX908 = false; +#endif // __gfx908__ + +#if defined(__gfx90a__) + static constexpr bool CK_TILE_ARCH_GFX90A = true; +#else + static constexpr bool CK_TILE_ARCH_GFX90A = false; +#endif // __gfx90a__ + +#if defined(__gfx942__) + static constexpr bool CK_TILE_ARCH_GFX942 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX942 = false; +#endif // __gfx942__ + +#if defined(__gfx950__) + static constexpr bool CK_TILE_ARCH_GFX950 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX950 = false; +#endif // __gfx950__ + + // GFX10 +#if defined(__gfx1030__) + static constexpr bool CK_TILE_ARCH_GFX1030 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1030 = false; +#endif // __gfx1030__ + +#if defined(__gfx1031__) + static constexpr bool CK_TILE_ARCH_GFX1031 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1031 = false; +#endif // __gfx1031__ + +#if defined(__gfx1032__) + static constexpr bool CK_TILE_ARCH_GFX1032 = true; +#else + 
static constexpr bool CK_TILE_ARCH_GFX1032 = false; +#endif // __gfx1032__ + +#if defined(__gfx1034__) + static constexpr bool CK_TILE_ARCH_GFX1034 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1034 = false; +#endif // __gfx1034__ + +#if defined(__gfx1035__) + static constexpr bool CK_TILE_ARCH_GFX1035 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1035 = false; +#endif // __gfx1035__ + +#if defined(__gfx1036__) + static constexpr bool CK_TILE_ARCH_GFX1036 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1036 = false; +#endif // __gfx1036__ + +#if defined(__gfx10_3_generic__) + static constexpr bool CK_TILE_ARCH_GFX10_3_GENERIC = true; +#else + static constexpr bool CK_TILE_ARCH_GFX10_3_GENERIC = false; +#endif // __gfx10_3_generic__ + + // GFX11 +#if defined(__gfx1100__) + static constexpr bool CK_TILE_ARCH_GFX1100 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1100 = false; +#endif // __gfx1100__ + +#if defined(__gfx1101__) + static constexpr bool CK_TILE_ARCH_GFX1101 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1101 = false; +#endif // __gfx1101__ + +#if defined(__gfx1102__) + static constexpr bool CK_TILE_ARCH_GFX1102 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1102 = false; +#endif // __gfx1102__ + +#if defined(__gfx1103__) + static constexpr bool CK_TILE_ARCH_GFX1103 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1103 = false; +#endif // __gfx1103__ + +#if defined(__gfx1150__) + static constexpr bool CK_TILE_ARCH_GFX1150 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1150 = false; +#endif // __gfx1150__ + +#if defined(__gfx1151__) + static constexpr bool CK_TILE_ARCH_GFX1151 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1151 = false; +#endif // __gfx1151__ + +#if defined(__gfx1152__) + static constexpr bool CK_TILE_ARCH_GFX1152 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1152 = false; +#endif // __gfx1152__ + +#if defined(__gfx11_generic__) + static constexpr bool 
CK_TILE_ARCH_GFX11_GENERIC = true; +#else + static constexpr bool CK_TILE_ARCH_GFX11_GENERIC = false; +#endif // __gfx11_generic__ + + // GFX12 +#if defined(__gfx1200__) + static constexpr bool CK_TILE_ARCH_GFX1200 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1200 = false; +#endif // __gfx1200__ + +#if defined(__gfx1201__) + static constexpr bool CK_TILE_ARCH_GFX1201 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1201 = false; +#endif // __gfx1201__ + +#if defined(__gfx12_generic__) + static constexpr bool CK_TILE_ARCH_GFX12_GENERIC = true; +#else + static constexpr bool CK_TILE_ARCH_GFX12_GENERIC = false; +#endif // __gfx12_generic__ +}; + +/** + * @brief Helper to count the number of times an item is contained within a list of values + * @tparam T The type of the search value + * @tparam Ts The types of the search list values + * @param search The value to search for + * @param searchList The list of values to search in + * @return true if the search value is in the search list, false otherwise + */ +template +// TODO: c++20 concept requires((std::is_convertible::value && ...) && (sizeof...(Ts) >= +// 1)) +CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... 
searchList) +{ + static_assert((std::is_convertible::value && ...), + "All search list values must be convertible to the search value type"); + static_assert(sizeof...(Ts) >= 1, "At least one value must be provided to search in"); + + return (static_cast(search == static_cast(searchList)) + ...); +} + +#define CK_TILE_COMPILER_TARGETS_LIST \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX908, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX90A, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1034, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1035, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1036, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX10_3_GENERIC, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1100, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1101, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1102, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1103, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1150, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1151, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1152, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX11_GENERIC, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1200, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1201, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX12_GENERIC + +// Sanity check: make sure only one target architecture is defined during device compile +static_assert(!amdgcn_compiler_target_state::CK_TILE_DEVICE_COMPILE || + count_values_of(true, CK_TILE_COMPILER_TARGETS_LIST) == 1u, + "Only one target architecture can be defined during device compile"); + +// Sanity check: make sure no device target architecture is defined during host compile 
+static_assert(!amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE || + count_values_of(true, CK_TILE_COMPILER_TARGETS_LIST) == 0u, + "No device target architecture can be defined during host compile"); + +} // namespace ck_tile::core diff --git a/include/ck_tile/core/utility/ignore.hpp b/include/ck_tile/core/utility/ignore.hpp index eead914954..b15a19aa2e 100644 --- a/include/ck_tile/core/utility/ignore.hpp +++ b/include/ck_tile/core/utility/ignore.hpp @@ -1,5 +1,5 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once @@ -14,6 +14,10 @@ struct ignore_t constexpr void operator=(T&&) const noexcept { } + template + constexpr void operator()(T&&...) const noexcept + { + } }; } // namespace detail diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index c43a64edaa..5ed49b7249 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -128,6 +128,25 @@ struct is_any_of { }; +/** + * @brief Helper to check if a value is in a list of values + * @tparam T The type of the search value + * @tparam Ts The types of the search list values + * @param search The value to search for + * @param searchList The list of values to search in + * @return true if the search value is in the search list, false otherwise + */ +template +// TODO: c++20 requires((std::is_convertible::value && ...) && (sizeof...(Ts) >= 1)) +CK_TILE_HOST_DEVICE static constexpr bool is_any_value_of(T search, Ts... 
searchList) +{ + static_assert((std::is_convertible::value && ...), + "All searchList values must be convertible to the type of search"); + static_assert(sizeof...(Ts) >= 1, "searchList must contain at least one value"); + + return ((search == static_cast(searchList)) || ...); +} + // Helper to check if a type is a specialization of a given template template class RefTemplate> struct is_specialization_of : std::false_type diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index d58c80377a..d4cef34ce0 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -28,6 +28,7 @@ add_subdirectory(add_rmsnorm2d_rdquant) add_subdirectory(gemm_block_scale) add_subdirectory(utility) add_subdirectory(reduce) +add_subdirectory(core) add_subdirectory(epilogue) add_subdirectory(atomic_add_op) add_subdirectory(fmha) diff --git a/test/ck_tile/core/CMakeLists.txt b/test/ck_tile/core/CMakeLists.txt new file mode 100644 index 0000000000..a0479470dd --- /dev/null +++ b/test/ck_tile/core/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(arch) diff --git a/test/ck_tile/core/arch/CMakeLists.txt b/test/ck_tile/core/arch/CMakeLists.txt new file mode 100644 index 0000000000..9e7aa0e197 --- /dev/null +++ b/test/ck_tile/core/arch/CMakeLists.txt @@ -0,0 +1,13 @@ +add_subdirectory(mma) + +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_arch test_arch.cpp) + target_compile_options(test_arch PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +else() + message(DEBUG "Skipping test_arch tests for current target") +endif() diff --git a/test/ck_tile/core/arch/mma/CMakeLists.txt b/test/ck_tile/core/arch/mma/CMakeLists.txt new file mode 100644 index 0000000000..07eccdcd90 --- /dev/null +++ b/test/ck_tile/core/arch/mma/CMakeLists.txt @@ -0,0 +1,12 @@ +# Currently ck_tile_gemm is only built on gfx94/gfx95 +set(EXAMPLE_GEMM_COMPILE_OPTIONS) 
+if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp) + target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +else() + message(DEBUG "Skipping ck_tile_gemm tests for current target") +endif() diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp new file mode 100644 index 0000000000..4121e199e2 --- /dev/null +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp @@ -0,0 +1,682 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma_selector.hpp" +#include "ck_tile/core/arch/mma/mma.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/host/hip_check_error.hpp" + +using namespace ck_tile; +using namespace ck_tile::core::arch; +using namespace ck_tile::core::arch::mma; + +// Dummy values for testing +constexpr uint32_t DummyTargetIdVal = 55555u; +using DummyCompilerTarget = amdgcn_target(DummyTargetIdVal)>; +struct DummyOpType; +struct DummyCtrlFlags +{ +}; + +/** @brief Returns true if the given target id matches the dummy */ +constexpr bool is_dummy_target(DummyCompilerTarget dummy) +{ + return static_cast(dummy.TARGET_ID) == DummyTargetIdVal; +} + +// Enable if for dummy architecture ID +// TODO: c++20 template +template +using enable_if_target_id_dummy_t = std::enable_if_t; + +// Specialization of amdgcn_mma for a supported dummy architecture. +// This way, we don't have to worry about underlying architectural details, +// and can focus on testing the mechanism of selecting supported vs unsupported architectures. 
+// TODO: c++20 template +template +struct amdgcn_mma> +{ + // Mfma operation type + using OpType = DummyOpType; + + // Register types + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + // Layout constants + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 2; + + static constexpr index_t kAMLane = 3; + static constexpr index_t kBNLane = 4; + static constexpr index_t kABKLane = 5; + static constexpr index_t kABKPerLane = 6; + + static constexpr index_t kCMLane = 7; + static constexpr index_t kCNLane = 8; + static constexpr index_t kCM0PerLane = 9; + static constexpr index_t kCM1PerLane = 10; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC) + { + return regsA + regsB + regsC; // Simple operation for testing + } +}; + +// Have an alias so we can test supported arch vs unsupported arch +// TODO: c++20 template +template +using DummyAmdgcnMma = + amdgcn_mma; + +/*! 
@struct MmaDefaultSelector + * @brief For dummy Id only, instantiate tests for both MFMA and WMMA selectors so we can them both + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam FragM Size of the M dimension of the fragment to decompose + * @tparam FragN Size of the N dimension of the fragment to decompose + * @tparam FragK Size of the K dimension of the fragment to decompose + * @tparam CompilerTarget The compiler target + */ +template +// TODO: c++20 amdgcn_target_arch_id CompilerTarget> +// TODO: requires +struct MmaDefaultSelector> +{ + using SelectedOp = DummyAmdgcnMma; +}; + +// Test case for supported architecture +TEST(TestAmdgcnMma, ArchSupported) +{ + // Instantiate MmaOp with the dummy supported CompilerTarget + using MmaOp = DummyAmdgcnMma; + + EXPECT_TRUE((!std::is_same_v)); + + // Additional tests for DummyArchSupported: check all member variables and types + + // Check OpType + EXPECT_TRUE( + (std::is_same::value)); // OpType is DummyOpType + + // Check AVecType, BVecType, CVecType + EXPECT_TRUE((std::is_same>::value)); + EXPECT_TRUE((std::is_same>::value)); + EXPECT_TRUE((std::is_same>::value)); + + // Check layout constants + EXPECT_EQ(MmaOp::kAMBlock, 1); + EXPECT_EQ(MmaOp::kBNBlock, 2); + + EXPECT_EQ(MmaOp::kAMLane, 3); + EXPECT_EQ(MmaOp::kBNLane, 4); + EXPECT_EQ(MmaOp::kABKLane, 5); + EXPECT_EQ(MmaOp::kABKPerLane, 6); + + EXPECT_EQ(MmaOp::kCMLane, 7); + EXPECT_EQ(MmaOp::kCNLane, 8); + EXPECT_EQ(MmaOp::kCM0PerLane, 9); + EXPECT_EQ(MmaOp::kCM1PerLane, 10); +} + +// Test case for unsupported architecture +TEST(TestAmdgcnMma, ArchUnsupported) +{ + // Instantiate MmaOp with the dummy unsupported CompilerTarget (e.g., HOST) + using MmaOp = DummyAmdgcnMma>; + + // OpType should be Unsupported + EXPECT_TRUE((std::is_same::value)); + + // AVecType, BVecType, CVecType should match default + EXPECT_TRUE((std::is_same>::value)); + 
EXPECT_TRUE((std::is_same>::value)); + EXPECT_TRUE((std::is_same>::value)); + + // Layout constants should match default values (typically 0) + EXPECT_EQ(MmaOp::kAMBlock, 0); + EXPECT_EQ(MmaOp::kBNBlock, 0); + + EXPECT_EQ(MmaOp::kAMLane, 0); + EXPECT_EQ(MmaOp::kBNLane, 0); + EXPECT_EQ(MmaOp::kABKLane, 0); + EXPECT_EQ(MmaOp::kABKPerLane, 0); + + EXPECT_EQ(MmaOp::kCMLane, 0); + EXPECT_EQ(MmaOp::kCNLane, 0); + EXPECT_EQ(MmaOp::kCM0PerLane, 0); + EXPECT_EQ(MmaOp::kCM1PerLane, 0); +} + +// Kernel to test amdgcn_mma::exec on device +template +__global__ void test_amdgcn_mma_exec_kernel(typename MmaOp::AVecType* a, + typename MmaOp::BVecType* b, + typename MmaOp::CVecType* c, + typename MmaOp::CVecType* out) +{ + // This is pseudo-mma behaviour to check the mechanics of mma. + // All threads write to the same values, so ensure that + // the inputs are uniform! + *out = MmaOp::exec(*a, *b, *c); +} + +TEST(TestAmdgcnMma, ArchSupportedExecDeviceOutput) +{ + using MmaOp = DummyAmdgcnMma; + using DataType = fp32_t; + + typename MmaOp::AVecType h_a; + typename MmaOp::BVecType h_b; + typename MmaOp::CVecType h_c; + typename MmaOp::CVecType h_out; + + // Fill input vectors with known values + for(size_t i = 0; i < sizeof(h_a) / sizeof(DataType); ++i) + { + reinterpret_cast(&h_a)[i] = static_cast(i + 1); + } + for(size_t i = 0; i < sizeof(h_b) / sizeof(DataType); ++i) + { + reinterpret_cast(&h_b)[i] = static_cast(i + 10); + } + for(size_t i = 0; i < sizeof(h_c) / sizeof(DataType); ++i) + { + reinterpret_cast(&h_c)[i] = static_cast(i + 100); + } + + typename MmaOp::AVecType* d_a; + typename MmaOp::BVecType* d_b; + typename MmaOp::CVecType* d_c; + typename MmaOp::CVecType* d_out; + + HIP_CHECK_ERROR(hipMalloc(&d_a, sizeof(h_a))); + HIP_CHECK_ERROR(hipMalloc(&d_b, sizeof(h_b))); + HIP_CHECK_ERROR(hipMalloc(&d_c, sizeof(h_c))); + HIP_CHECK_ERROR(hipMalloc(&d_out, sizeof(h_out))); + + HIP_CHECK_ERROR(hipMemcpy(d_a, &h_a, sizeof(h_a), hipMemcpyHostToDevice)); + 
HIP_CHECK_ERROR(hipMemcpy(d_b, &h_b, sizeof(h_b), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_c, &h_c, sizeof(h_c), hipMemcpyHostToDevice)); + + test_amdgcn_mma_exec_kernel<<<1, 1>>>(d_a, d_b, d_c, d_out); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + HIP_CHECK_ERROR(hipMemcpy(&h_out, d_out, sizeof(h_out), hipMemcpyDeviceToHost)); + + // Check that output matches expected: a + b + c + for(size_t i = 0; i < sizeof(h_out) / sizeof(DataType); ++i) + { + DataType expected = reinterpret_cast(&h_a)[i] + + reinterpret_cast(&h_b)[i] + + reinterpret_cast(&h_c)[i]; + EXPECT_EQ(reinterpret_cast(&h_out)[i], expected); + } + + HIP_CHECK_ERROR(hipFree(d_a)); + HIP_CHECK_ERROR(hipFree(d_b)); + HIP_CHECK_ERROR(hipFree(d_c)); + HIP_CHECK_ERROR(hipFree(d_out)); +} + +TEST(TestAmdgcnMma, ArchUnsupportedExecDeviceOutput) +{ + using MmaOp = DummyAmdgcnMma>; + using DataType = fp32_t; + + typename MmaOp::AVecType h_a{}; + typename MmaOp::BVecType h_b{}; + typename MmaOp::CVecType h_c{}; + typename MmaOp::CVecType h_out{}; + + // Fill C with known values + for(size_t i = 0; i < sizeof(h_c) / sizeof(DataType); ++i) + { + reinterpret_cast(&h_c)[i] = static_cast(i + 1); + } + + typename MmaOp::AVecType* d_a; + typename MmaOp::BVecType* d_b; + typename MmaOp::CVecType* d_c; + typename MmaOp::CVecType* d_out; + + HIP_CHECK_ERROR(hipMalloc(&d_a, sizeof(h_a))); + HIP_CHECK_ERROR(hipMalloc(&d_b, sizeof(h_b))); + HIP_CHECK_ERROR(hipMalloc(&d_c, sizeof(h_c))); + HIP_CHECK_ERROR(hipMalloc(&d_out, sizeof(h_out))); + + HIP_CHECK_ERROR(hipMemcpy(d_a, &h_a, sizeof(h_a), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_b, &h_b, sizeof(h_b), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_c, &h_c, sizeof(h_c), hipMemcpyHostToDevice)); + + test_amdgcn_mma_exec_kernel<<<1, 1>>>(d_a, d_b, d_c, d_out); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + HIP_CHECK_ERROR(hipMemcpy(&h_out, d_out, sizeof(h_out), hipMemcpyDeviceToHost)); + + // Check that output matches input C + for(size_t i 
= 0; i < sizeof(h_c) / sizeof(DataType); ++i) + { + EXPECT_EQ(reinterpret_cast(&h_out)[i], reinterpret_cast(&h_c)[i]); + } + + HIP_CHECK_ERROR(hipFree(d_a)); + HIP_CHECK_ERROR(hipFree(d_b)); + HIP_CHECK_ERROR(hipFree(d_c)); + HIP_CHECK_ERROR(hipFree(d_out)); +} + +#include "ck_tile/core/arch/mma/mma_traits.hpp" + +// Test MmaOpParams for supported DummyAmdgcnMma, including all member variables +TEST(TestAmdgcnMma, MmaOpParamsTraitsSupportedMembers) +{ + using MmaOp = DummyAmdgcnMma; + using Traits = MmaOpParams; + + // Check MmaOpParams members + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same::value)); + EXPECT_EQ(Traits::BlockM, 16u); + EXPECT_EQ(Traits::BlockN, 16u); + EXPECT_EQ(Traits::BlockK, 16u); + EXPECT_TRUE((std::is_same::value)); +} + +// Test MmaOpParams for unsupported DummyAmdgcnMma, including all member variables +TEST(TestAmdgcnMma, MmaOpParamsUnsupportedMembers) +{ + using MmaOp = DummyAmdgcnMma>; + using Traits = MmaOpParams; + + // Check MmaOpParams members + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same::value)); + EXPECT_EQ(Traits::BlockM, 16u); + EXPECT_EQ(Traits::BlockN, 16u); + EXPECT_EQ(Traits::BlockK, 16u); + EXPECT_TRUE((std::is_same::value)); +} + +// Test MmaOpTraits for supported DummyAmdgcnMma, including all member variables +TEST(TestAmdgcnMma, MmaOpTraitsSupportedMembers) +{ + using MmaOp = DummyAmdgcnMma; + using Traits = MmaOpTraits; + + // Check MmaOpTraits member variables + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same>::value)); + EXPECT_TRUE((std::is_same>::value)); + EXPECT_TRUE((std::is_same>::value)); + EXPECT_EQ(Traits::kAMBlock, 1); + EXPECT_EQ(Traits::kBNBlock, 2); + EXPECT_EQ(Traits::kAMLane, 3); + EXPECT_EQ(Traits::kBNLane, 4); + EXPECT_EQ(Traits::kABKLane, 5); + EXPECT_EQ(Traits::kABKPerLane, 6); + EXPECT_EQ(Traits::kCMLane, 7); + EXPECT_EQ(Traits::kCNLane, 8); + 
EXPECT_EQ(Traits::kCM0PerLane, 9); + EXPECT_EQ(Traits::kCM1PerLane, 10); + EXPECT_FALSE(Traits::IsMfma); + EXPECT_FALSE(Traits::IsWmma); + EXPECT_TRUE(Traits::IsSupported); +} + +// Test MmaOpTraits for unsupported DummyAmdgcnMma, including all member variables +TEST(TestAmdgcnMma, MmaOpTraitsUnsupportedMembers) +{ + using MmaOp = DummyAmdgcnMma>; + using Traits = MmaOpTraits; + + // Check MmaOpTraits member variables + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same>::value)); + EXPECT_TRUE((std::is_same>::value)); + EXPECT_TRUE((std::is_same>::value)); + EXPECT_EQ(Traits::kAMBlock, 0); + EXPECT_EQ(Traits::kBNBlock, 0); + EXPECT_EQ(Traits::kAMLane, 0); + EXPECT_EQ(Traits::kBNLane, 0); + EXPECT_EQ(Traits::kABKLane, 0); + EXPECT_EQ(Traits::kABKPerLane, 0); + EXPECT_EQ(Traits::kCMLane, 0); + EXPECT_EQ(Traits::kCNLane, 0); + EXPECT_EQ(Traits::kCM0PerLane, 0); + EXPECT_EQ(Traits::kCM1PerLane, 0); + EXPECT_FALSE(Traits::IsMfma); + EXPECT_FALSE(Traits::IsWmma); + EXPECT_FALSE(Traits::IsSupported); +} + +// Test MmaDefaultSelector for supported DummyAmdgcnMma +TEST(TestAmdgcnMma, MmaDefaultSelectorSupported) +{ + // Direct selection of the supported dummy instruction + using SelectedMma = + typename MmaDefaultSelector:: + SelectedOp; + // Should select DummyAmdgcnMma specialization + EXPECT_TRUE((std::is_same>::value)); + // OpType should be DummyOpType + EXPECT_TRUE((std::is_same::value)); + // IsSupported should be true + EXPECT_TRUE(MmaOpTraits::IsSupported); +} + +// Test MmaDefaultSelector for unsupported DummyAmdgcnMma +TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupported) +{ + // Direct selection of the unsupported dummy instruction + using SelectedMma = + MmaDefaultSelector>::SelectedOp; + // OpType should be Unsupported + EXPECT_TRUE((std::is_same::value)); + // IsSupported should be false + EXPECT_FALSE(MmaOpTraits::IsSupported); +} + +// Test MmaDefaultSelector for supported DummyAmdgcnMma on fragment sizes other than 16x16x16 +// This tests that 
the selector can still pick the correct MMA op even if the fragment sizes differ +TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment) +{ + // Select indirectly with a fragment size of 256x128x64 + using SelectedMma = + MmaDefaultSelector:: + SelectedOp; + // Should select DummyAmdgcnMma specialization + EXPECT_TRUE((std::is_same>::value)); + // OpType should be DummyOpType + EXPECT_TRUE((std::is_same::value)); + // IsSupported should be true + EXPECT_TRUE(MmaOpTraits::IsSupported); +} + +// Test MmaDefaultSelector for a different block size and supported arch +TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedFragment) +{ + // This should fall back to unsupported since DummyAmdgcnMma only supports 16x16x16 + using SelectedMma = + MmaDefaultSelector::SelectedOp; + EXPECT_FALSE((std::is_same::value)); + EXPECT_TRUE(MmaOpTraits::IsSupported); +} + +// Test MmaDefaultSelector for a different data type (fp16_t) and unsupported arch +TEST(TestAmdgcnMma, MmaDefaultSelectorFp16Unsupported) +{ + using SelectedMma = + MmaDefaultSelector>::SelectedOp; + // Should select default amdgcn_mma (Unsupported) + EXPECT_TRUE((std::is_same::value)); + EXPECT_FALSE(MmaOpTraits::IsSupported); +} + +// Test on real hardware for MmaOp selection. +// This is not a GEMM kernel, but a simple test to ensure that the selected MmaOp works correctly on +// real hardware. Assumption: inputs are all 1's The multiply-accumulate functionality can be tested +// here by looping over the k dimension and accumulating the results. They should be equal to FragK +// regardless of hardware. 
+template +__global__ void test_accum_over_k(void* a, void* b, void* c, void* out) +{ + using Selector = MmaDefaultSelector; + + using MmaOp = typename Selector::SelectedOp; + using MmaTraits = MmaOpTraits; + + using CVecType = typename MmaOp::CVecType; + + static constexpr uint32_t kIters = FragK / MmaTraits::BlockK; + + // Initialize the accumulator + CVecType result = *reinterpret_cast(c); + + // Accumulate input AxB over FragK/BlockK iterations + for(uint32_t i = 0; i < kIters; ++i) + { + result = MmaOp::exec(*reinterpret_cast(a), + *reinterpret_cast(b), + result); + } + + *reinterpret_cast(out) = result; +} + +// Do a live test. At minimum, there should be a solution on real hardware for F16_F16_F32_16x16x32. +TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real) +{ + int devCount; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); + + hipDeviceProp_t devProp; + HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); + + auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); + bool hasDevice = static_cast(devCount > 0); + int deviceWarpSize = devProp.warpSize; + + // TODO: c++20 add check for arch id + if(!hasDevice || (currentArchId == amdgcn_target_id::HOST)) + { + GTEST_SKIP() << "No HIP device found. Skipping test."; + } + + using AType = fp16_t; + using BType = fp16_t; + using CType = fp32_t; + + // Fragment size, also the expected block size from the selector. + // Note: Actual blockK might be slightly different due to hardware implementation, but the + // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is + // correct. 
+ static constexpr uint32_t FragM = 16; + static constexpr uint32_t FragN = 16; + static constexpr uint32_t FragK = 32; + static constexpr uint32_t BlockM = FragM; + static constexpr uint32_t BlockN = FragN; + static constexpr uint32_t BlockK = FragK; + + // Gfx11 has input data duplication and no accumulator padding (MultiplierC = 1) + // TODO: c++20 use is_target_family_gfx11(currentArchId) + bool isGfx11 = (currentArchId >= amdgcn_target_id::GFX1100) && + (currentArchId <= amdgcn_target_id::GFX11_GENERIC); + uint32_t MultiplierA = isGfx11 ? 2 : 1; + uint32_t MultiplierB = isGfx11 ? 2 : 1; + uint32_t MultiplierC = 1; + + // The number of elements per thread + uint32_t AElements = BlockM * BlockK / deviceWarpSize * MultiplierA; + uint32_t BElements = BlockN * BlockK / deviceWarpSize * MultiplierB; + uint32_t CElements = BlockM * BlockN / deviceWarpSize * MultiplierC; + + uint32_t ASize = AElements * sizeof(AType); + uint32_t BSize = BElements * sizeof(BType); + uint32_t CSize = CElements * sizeof(CType); + + // Initialize A and B to all 1's, C to all 0's + std::vector h_a(AElements, static_cast(1)); + std::vector h_b(BElements, static_cast(1)); + std::vector h_c(CElements, static_cast(0)); + std::vector h_out(CElements, static_cast(0)); + + AType* d_a; + BType* d_b; + CType* d_c; + CType* d_out; + + HIP_CHECK_ERROR(hipMalloc(&d_a, ASize)); + HIP_CHECK_ERROR(hipMalloc(&d_b, BSize)); + HIP_CHECK_ERROR(hipMalloc(&d_c, CSize)); + HIP_CHECK_ERROR(hipMalloc(&d_out, CSize)); + + // Copy inputs to device + HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); + + // Need at least 1 WG with 64 threads to get defined MFMA/WMMA behaviour + test_accum_over_k<<<1, 64>>>(d_a, d_b, d_c, d_out); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, 
hipMemcpyDeviceToHost)); + + // Output should be FragK for all elements, because the inputs are all 1's + for(size_t i = 0; i < CElements; ++i) + { + CType expected = static_cast(FragK); + + EXPECT_NEAR(h_out[i], expected, 1e-3); + } + + HIP_CHECK_ERROR(hipFree(d_a)); + HIP_CHECK_ERROR(hipFree(d_b)); + HIP_CHECK_ERROR(hipFree(d_c)); + HIP_CHECK_ERROR(hipFree(d_out)); +} + +// Do a live test. At minimum, there should be a solution on real hardware for F16_F16_F32_16x16x32 +// The selector should be able to pick the correct MmaOp as a multiple of 16x16x32, even if the +// fragment sizes are larger than 16x16x32. This tests that the selector can handle larger fragment +// sizes and still select the correct MmaOp. +TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real) +{ + int devCount; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); + + hipDeviceProp_t devProp; + HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); + + auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); + bool hasDevice = static_cast(devCount > 0); + int deviceWarpSize = devProp.warpSize; + + // TODO: c++20 add check for arch id + if(!hasDevice || (currentArchId == amdgcn_target_id::HOST)) + { + GTEST_SKIP() << "No HIP device found. Skipping test."; + } + + using AType = fp16_t; + using BType = fp16_t; + using CType = fp32_t; + + // Fragment size to test for decomposition. + // We expect the selector to pick a 16x16 block + static constexpr uint32_t FragM = 112; + static constexpr uint32_t FragN = 112; + static constexpr uint32_t FragK = 128; + + // The expected block size from the selector (multiple of 16). + // Note: Actual blockK might be slightly different due to hardware implementation, but the + // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is + // correct. 
+ static constexpr uint32_t BlockM = 16; + static constexpr uint32_t BlockN = 16; + static constexpr uint32_t BlockK = 32; + + // Gfx11 has input data duplication and no accumulator padding (MultiplierC = 1) + // TODO: c++20 use is_target_family_gfx11(currentArchId) + bool isGfx11 = (currentArchId >= amdgcn_target_id::GFX1100) && + (currentArchId <= amdgcn_target_id::GFX11_GENERIC); + uint32_t MultiplierA = isGfx11 ? 2 : 1; + uint32_t MultiplierB = isGfx11 ? 2 : 1; + uint32_t MultiplierC = 1; + + // The number of elements per thread + uint32_t AElements = BlockM * BlockK / deviceWarpSize * MultiplierA; + uint32_t BElements = BlockN * BlockK / deviceWarpSize * MultiplierB; + uint32_t CElements = BlockM * BlockN / deviceWarpSize * MultiplierC; + + uint32_t ASize = AElements * sizeof(AType); + uint32_t BSize = BElements * sizeof(BType); + uint32_t CSize = CElements * sizeof(CType); + + // Initialize A and B to all 1's, C to all 0's + std::vector h_a(AElements, static_cast(1)); + std::vector h_b(BElements, static_cast(1)); + std::vector h_c(CElements, static_cast(0)); + std::vector h_out(CElements, static_cast(0)); + + AType* d_a; + BType* d_b; + CType* d_c; + CType* d_out; + + HIP_CHECK_ERROR(hipMalloc(&d_a, ASize)); + HIP_CHECK_ERROR(hipMalloc(&d_b, BSize)); + HIP_CHECK_ERROR(hipMalloc(&d_c, CSize)); + HIP_CHECK_ERROR(hipMalloc(&d_out, CSize)); + + // Copy inputs to device + HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); + + // Need at least 1 WG with 64 threads to get defined MFMA/WMMA behaviour + test_accum_over_k<<<1, 64>>>(d_a, d_b, d_c, d_out); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); + + // Output should be FragK for all elements, because the inputs are all 1's + for(size_t i = 0; i < CElements; 
++i) + { + CType expected = static_cast(FragK); + + EXPECT_NEAR(h_out[i], expected, 1e-3); + } + + HIP_CHECK_ERROR(hipFree(d_a)); + HIP_CHECK_ERROR(hipFree(d_b)); + HIP_CHECK_ERROR(hipFree(d_c)); + HIP_CHECK_ERROR(hipFree(d_out)); +} diff --git a/test/ck_tile/core/arch/test_arch.cpp b/test/ck_tile/core/arch/test_arch.cpp new file mode 100644 index 0000000000..2d553c1595 --- /dev/null +++ b/test/ck_tile/core/arch/test_arch.cpp @@ -0,0 +1,396 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include "ck_tile/core/arch/arch.hpp" + +using namespace ck_tile; +using namespace ck_tile::core::arch; + +// Test address_space_enum string conversion +TEST(TestArch, AddressSpaceToString) +{ + EXPECT_STREQ(address_space_to_string(address_space_enum::generic), "generic"); + EXPECT_STREQ(address_space_to_string(address_space_enum::global), "global"); + EXPECT_STREQ(address_space_to_string(address_space_enum::lds), "lds"); + EXPECT_STREQ(address_space_to_string(address_space_enum::sgpr), "sgpr"); + EXPECT_STREQ(address_space_to_string(address_space_enum::constant), "constant"); + EXPECT_STREQ(address_space_to_string(address_space_enum::vgpr), "vgpr"); + EXPECT_STREQ(address_space_to_string(static_cast(999)), "unknown"); +} + +#if 1 // __cplusplus <= 201703L + +// Tests make_amdgcn_gf9_target function +TEST(ArchTest, MakeGfx9TargetFields) +{ + constexpr auto target = make_amdgcn_gfx9_target(); + EXPECT_EQ(target.TARGET_ID, amdgcn_target_id::GFX908); + EXPECT_EQ(target.FAMILY_ID, amdgcn_target_family_id::GFX9); + EXPECT_EQ(target.ARCH_ID, amdgcn_target_arch_id::CDNA); + EXPECT_EQ(target.WAVE_SIZE_ID, amdgcn_target_wave_size_id::WAVE64); +} + +// Tests make_amdgcn_gfx11_target function +TEST(ArchTest, MakeGfx11TargetFields) +{ + constexpr auto target = make_amdgcn_gfx11_target(); + EXPECT_EQ(target.TARGET_ID, amdgcn_target_id::GFX1100); + EXPECT_EQ(target.FAMILY_ID, amdgcn_target_family_id::GFX11); + 
EXPECT_EQ(target.ARCH_ID, amdgcn_target_arch_id::RDNA); + EXPECT_EQ(target.WAVE_SIZE_ID, amdgcn_target_wave_size_id::WAVE32); +} + +// Tests make_amdgcn_gfx12_target function +TEST(ArchTest, MakeGfx12TargetFields) +{ + constexpr auto target = make_amdgcn_gfx12_target(); + EXPECT_EQ(target.TARGET_ID, amdgcn_target_id::GFX1200); + EXPECT_EQ(target.FAMILY_ID, amdgcn_target_family_id::GFX12); + EXPECT_EQ(target.ARCH_ID, amdgcn_target_arch_id::RDNA); + EXPECT_EQ(target.WAVE_SIZE_ID, amdgcn_target_wave_size_id::WAVE32); +} + +// Tests default amdgcn_target +TEST(ArchTest, DefaultTargetIsHost) +{ + constexpr auto target = amdgcn_target<>{}; + EXPECT_EQ(target.TARGET_ID, amdgcn_target_id::HOST); + EXPECT_EQ(target.FAMILY_ID, amdgcn_target_family_id::HOST); + EXPECT_EQ(target.ARCH_ID, amdgcn_target_arch_id::HOST); + EXPECT_EQ(target.WAVE_SIZE_ID, amdgcn_target_wave_size_id::HOST); +} + +// Tests get_compiler_target function on host +TEST(ArchTest, GetCompilerTargetDefaultIsHost) +{ + // By default, get_compiler_target should return HOST arch id because we aren't on device + auto target = get_compiler_target(); + EXPECT_EQ(target.TARGET_ID, amdgcn_target_id::HOST); + EXPECT_EQ(target.FAMILY_ID, amdgcn_target_family_id::HOST); + EXPECT_EQ(target.ARCH_ID, amdgcn_target_arch_id::HOST); + EXPECT_EQ(target.WAVE_SIZE_ID, amdgcn_target_wave_size_id::HOST); +} + +// SFINAE test setup for incoming acceptable target architecture ids +template +struct SFINAETestTargetIdGfx908OrGfx90a +{ + static constexpr bool value = false; +}; + +// Acceptable target arch ids: GFX908, GFX90A +template +struct SFINAETestTargetIdGfx908OrGfx90a< + Target, + enable_if_target_id_t> +{ + static constexpr bool value = true; +}; + +// SFINAE test setup for incoming acceptable target family ids +template +struct SFINAETestFamilyIdGfx9 +{ + static constexpr bool value = false; +}; + +// Acceptable target arch family ids: GFX9 +template +struct SFINAETestFamilyIdGfx9> +{ + static constexpr bool value = true; 
+}; + +// SFINAE test setup for incoming acceptable target architecture ids +template +struct SFINAETestArchIdCdna +{ + static constexpr bool value = false; +}; + +// Acceptable target arch ids: CDNA +template +struct SFINAETestArchIdCdna> +{ + static constexpr bool value = true; +}; + +// SFINAE test setup for incoming acceptable target wave size ids +template +struct SFINAETestWaveSizeIdWave64 +{ + static constexpr bool value = false; +}; + +// Acceptable target arch wave size ids: WAVE64 +template +struct SFINAETestWaveSizeIdWave64< + Target, + enable_if_target_wave_size_id_t> +{ + static constexpr bool value = true; +}; + +// Test SFINAE enablers with various architectures +TEST(ArchTest, TestSFINAEEnablersGfx9CdnaWave64) +{ + static constexpr auto target = make_amdgcn_gfx9_target(); + using Target = decltype(target); + EXPECT_EQ(true, SFINAETestTargetIdGfx908OrGfx90a::value); + EXPECT_EQ(true, SFINAETestFamilyIdGfx9::value); + EXPECT_EQ(true, SFINAETestArchIdCdna::value); + EXPECT_EQ(true, SFINAETestWaveSizeIdWave64::value); +} + +TEST(ArchTest, TestSFINAEEnablersGfx11RdnaWave32) +{ + static constexpr auto target = make_amdgcn_gfx11_target(); + using Target = decltype(target); + EXPECT_EQ(false, SFINAETestTargetIdGfx908OrGfx90a::value); + EXPECT_EQ(false, SFINAETestFamilyIdGfx9::value); + EXPECT_EQ(false, SFINAETestArchIdCdna::value); + EXPECT_EQ(false, SFINAETestWaveSizeIdWave64::value); +} + +TEST(ArchTest, TestSFINAEEnablersGfx12RdnaWave32) +{ + static constexpr auto target = make_amdgcn_gfx12_target(); + using Target = decltype(target); + EXPECT_EQ(false, SFINAETestTargetIdGfx908OrGfx90a::value); + EXPECT_EQ(false, SFINAETestFamilyIdGfx9::value); + EXPECT_EQ(false, SFINAETestArchIdCdna::value); + EXPECT_EQ(false, SFINAETestWaveSizeIdWave64::value); +} + +TEST(ArchTest, TestSFINAEEnablersHost) +{ + static constexpr auto target = amdgcn_target<>{}; + using Target = decltype(target); + EXPECT_EQ(false, SFINAETestTargetIdGfx908OrGfx90a::value); + 
EXPECT_EQ(false, SFINAETestFamilyIdGfx9::value); + EXPECT_EQ(false, SFINAETestArchIdCdna::value); + // TODO: Should host be considered as WAVE64 or not? For now, we will consider it as WAVE64 + EXPECT_EQ(true, SFINAETestWaveSizeIdWave64::value); +} + +TEST(ArchTest, TestSFINAEEnablersGfx9CdnaWave32) +{ + static constexpr auto target = amdgcn_target{}; + using Target = decltype(target); + EXPECT_EQ(true, SFINAETestTargetIdGfx908OrGfx90a::value); + EXPECT_EQ(true, SFINAETestFamilyIdGfx9::value); + EXPECT_EQ(true, SFINAETestArchIdCdna::value); + EXPECT_EQ(false, SFINAETestWaveSizeIdWave64::value); +} + +TEST(ArchTest, TestSFINAEEnablersMix) +{ + static constexpr auto target = amdgcn_target{}; + using Target = decltype(target); + EXPECT_EQ(true, SFINAETestTargetIdGfx908OrGfx90a::value); + EXPECT_EQ(false, SFINAETestFamilyIdGfx9::value); + EXPECT_EQ(true, SFINAETestArchIdCdna::value); + EXPECT_EQ(false, SFINAETestWaveSizeIdWave64::value); +} + +#elif 0 // TODO: c++20 tests + +// Tests make_amdgcn_gf9_target function +TEST(ArchTest, MakeGfx9TargetFields) +{ + constexpr auto target = make_amdgcn_gfx9_target(amdgcn_target_id::GFX908); + EXPECT_EQ(target.TARGET_ID, amdgcn_target_id::GFX908); + EXPECT_EQ(target.FAMILY_ID, amdgcn_target_family_id::GFX9); + EXPECT_EQ(target.ARCH_ID, amdgcn_target_arch_id::CDNA); + EXPECT_EQ(target.WAVE_SIZE_ID, amdgcn_target_wave_size_id::WAVE64); +} + +// Tests make_amdgcn_gfx11_target function +TEST(ArchTest, MakeGfx11TargetFields) +{ + constexpr auto target = make_amdgcn_gfx11_target(amdgcn_target_id::GFX1100); + EXPECT_EQ(target.TARGET_ID, amdgcn_target_id::GFX1100); + EXPECT_EQ(target.FAMILY_ID, amdgcn_target_family_id::GFX11); + EXPECT_EQ(target.ARCH_ID, amdgcn_target_arch_id::RDNA); + EXPECT_EQ(target.WAVE_SIZE_ID, amdgcn_target_wave_size_id::WAVE32); +} + +// Tests make_amdgcn_gfx12_target function +TEST(ArchTest, MakeGfx12TargetFields) +{ + constexpr auto target = make_amdgcn_gfx12_target(amdgcn_target_id::GFX1200); + 
EXPECT_EQ(target.TARGET_ID, amdgcn_target_id::GFX1200); + EXPECT_EQ(target.FAMILY_ID, amdgcn_target_family_id::GFX12); + EXPECT_EQ(target.ARCH_ID, amdgcn_target_arch_id::RDNA); + EXPECT_EQ(target.WAVE_SIZE_ID, amdgcn_target_wave_size_id::WAVE32); +} + +// Tests default amdgcn_target +TEST(ArchTest, DefaultTargetIsHost) +{ + constexpr amdgcn_target target{}; + EXPECT_EQ(target.TARGET_ID, amdgcn_target_id::HOST); + EXPECT_EQ(target.FAMILY_ID, amdgcn_target_family_id::HOST); + EXPECT_EQ(target.ARCH_ID, amdgcn_target_arch_id::HOST); + EXPECT_EQ(target.WAVE_SIZE_ID, amdgcn_target_wave_size_id::HOST); +} + +// Tests get_compiler_target function on host +TEST(ArchTest, GetCompilerTargetDefaultIsHost) +{ + // By default, get_compiler_target should return HOST arch id because we aren't on device + auto target = get_compiler_target(); + EXPECT_EQ(target.TARGET_ID, amdgcn_target_id::HOST); + EXPECT_EQ(target.FAMILY_ID, amdgcn_target_family_id::HOST); + EXPECT_EQ(target.ARCH_ID, amdgcn_target_arch_id::HOST); + EXPECT_EQ(target.WAVE_SIZE_ID, amdgcn_target_wave_size_id::HOST); +} + +// SFINAE test setup for incoming acceptable target architecture ids +template +struct SFINAETestTargetIdGfx908OrGfx90a +{ + static constexpr bool value = false; +}; + +// Acceptable target arch ids: GFX908, GFX90A +template +struct SFINAETestTargetIdGfx908OrGfx90a< + Target, + enable_if_target_id_t> +{ + static constexpr bool value = true; +}; + +// SFINAE test setup for incoming acceptable target family ids +template +struct SFINAETestFamilyIdGfx9 +{ + static constexpr bool value = false; +}; + +// Acceptable target arch family ids: GFX9 +template +struct SFINAETestFamilyIdGfx9> +{ + static constexpr bool value = true; +}; + +// SFINAE test setup for incoming acceptable target architecture ids +template +struct SFINAETestArchIdCdna +{ + static constexpr bool value = false; +}; + +// Acceptable target arch ids: CDNA +template +struct SFINAETestArchIdCdna> +{ + static constexpr bool value = true; +}; 
+ +// SFINAE test setup for incoming acceptable target wave size ids +template +struct SFINAETestWaveSizeIdWave64 +{ + static constexpr bool value = false; +}; + +// Acceptable target arch wave size ids: WAVE64 +template +struct SFINAETestWaveSizeIdWave64< + Target, + enable_if_target_wave_size_id_t> +{ + static constexpr bool value = true; +}; + +// Test SFINAE enablers with various architectures +TEST(ArchTest, TestSFINAEEnablersGfx9CdnaWave64) +{ + static constexpr auto target = + amdgcn_target{.TARGET_ID = amdgcn_target_id::GFX908, + .FAMILY_ID = amdgcn_target_family_id::GFX9, + .ARCH_ID = amdgcn_target_arch_id::CDNA, + .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE64}; + EXPECT_EQ(true, SFINAETestTargetIdGfx908OrGfx90a::value); + EXPECT_EQ(true, SFINAETestFamilyIdGfx9::value); + EXPECT_EQ(true, SFINAETestArchIdCdna::value); + EXPECT_EQ(true, SFINAETestWaveSizeIdWave64::value); +} + +TEST(ArchTest, TestSFINAEEnablersGfx11RdnaWave32) +{ + static constexpr auto target = + amdgcn_target{.TARGET_ID = amdgcn_target_id::GFX1100, + .FAMILY_ID = amdgcn_target_family_id::GFX11, + .ARCH_ID = amdgcn_target_arch_id::RDNA, + .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32}; + EXPECT_EQ(false, SFINAETestTargetIdGfx908OrGfx90a::value); + EXPECT_EQ(false, SFINAETestFamilyIdGfx9::value); + EXPECT_EQ(false, SFINAETestArchIdCdna::value); + EXPECT_EQ(false, SFINAETestWaveSizeIdWave64::value); +} + +TEST(ArchTest, TestSFINAEEnablersGfx12RdnaWave32) +{ + static constexpr auto target = + amdgcn_target{.TARGET_ID = amdgcn_target_id::GFX1200, + .FAMILY_ID = amdgcn_target_family_id::GFX12, + .ARCH_ID = amdgcn_target_arch_id::RDNA, + .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32}; + EXPECT_EQ(false, SFINAETestTargetIdGfx908OrGfx90a::value); + EXPECT_EQ(false, SFINAETestFamilyIdGfx9::value); + EXPECT_EQ(false, SFINAETestArchIdCdna::value); + EXPECT_EQ(false, SFINAETestWaveSizeIdWave64::value); +} + +TEST(ArchTest, TestSFINAEEnablersHost) +{ + static constexpr auto target = 
amdgcn_target{.TARGET_ID = amdgcn_target_id::HOST, + .FAMILY_ID = amdgcn_target_family_id::HOST, + .ARCH_ID = amdgcn_target_arch_id::HOST, + .WAVE_SIZE_ID = amdgcn_target_wave_size_id::HOST}; + EXPECT_EQ(false, SFINAETestTargetIdGfx908OrGfx90a::value); + EXPECT_EQ(false, SFINAETestFamilyIdGfx9::value); + EXPECT_EQ(false, SFINAETestArchIdCdna::value); + EXPECT_EQ(false, SFINAETestWaveSizeIdWave64::value); +} + +TEST(ArchTest, TestSFINAEEnablersGfx9CdnaWave32) +{ + static constexpr auto target = + amdgcn_target{.TARGET_ID = amdgcn_target_id::GFX908, + .FAMILY_ID = amdgcn_target_family_id::GFX9, + .ARCH_ID = amdgcn_target_arch_id::CDNA, + .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32}; + EXPECT_EQ(true, SFINAETestTargetIdGfx908OrGfx90a::value); + EXPECT_EQ(true, SFINAETestFamilyIdGfx9::value); + EXPECT_EQ(true, SFINAETestArchIdCdna::value); + EXPECT_EQ(false, SFINAETestWaveSizeIdWave64::value); +} + +TEST(ArchTest, TestSFINAEEnablersMix) +{ + static constexpr auto target = + amdgcn_target{.TARGET_ID = amdgcn_target_id::GFX90A, + .FAMILY_ID = amdgcn_target_family_id::GFX12, + .ARCH_ID = amdgcn_target_arch_id::CDNA, + .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32}; + EXPECT_EQ(true, SFINAETestTargetIdGfx908OrGfx90a::value); + EXPECT_EQ(false, SFINAETestFamilyIdGfx9::value); + EXPECT_EQ(true, SFINAETestArchIdCdna::value); + EXPECT_EQ(false, SFINAETestWaveSizeIdWave64::value); +} + +#endif // __cplusplus <= 201703L