mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
First look at mfma / wmma unification (#2704)
* First look at mfma / wmma unification * Refactor * Re-org file structure * Restructure transform selection and WaveWiseMma class * Update license files. Add missing gfx1151 support. Change wave size for HOST to 1. Update datatypes naming consistency * Fixes default MmaSelector implementation * Adds unit tests for amdgcn_mma and arch * Consolidate common arch id checks to constexpr functions. Strongly type ids as amdgcn_target_arch_id object. * Refactor is_any_value_of * Fixes mma_selector logic * Fix typo * Add mma selector test for tile decomposition * Fix compilation of mma.hpp * Revert back to c++17 compatibility * Fix compiler error by returning index_t from get_warp_size() * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Fixes compiler error for missing is_wave32() function * Fixes compiler error for host wave_size() should be 64 * Fixes compiler errors where __cpp_concepts is not defined * Fixes compiler errors where __cpp_concepts is not defined * Fix test failure for host is wave64 by default --------- Co-authored-by: Chris Millette <you@example.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
8111572785
commit
b9c6cb1452
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
@@ -60,15 +61,753 @@ enum struct memory_operation_enum : std::uint16_t
|
||||
add
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
|
||||
namespace core::arch {
|
||||
|
||||
/**
|
||||
* @enum amdgcn_target_id
|
||||
* @brief Defines constants for AMDGCN architecture target IDs
|
||||
*/
|
||||
enum struct amdgcn_target_id
|
||||
{
|
||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
return 64;
|
||||
#else
|
||||
return 32;
|
||||
#endif
|
||||
GFX908 = 0x0908, // MI-100...
|
||||
GFX90A = 0x090A,
|
||||
GFX942 = 0x0942,
|
||||
GFX950 = 0x0950,
|
||||
GFX1030 = 0x1030,
|
||||
GFX1031 = 0x1031,
|
||||
GFX1032 = 0x1032,
|
||||
GFX1034 = 0x1034,
|
||||
GFX1035 = 0x1035,
|
||||
GFX1036 = 0x1036,
|
||||
GFX103_GENERIC = 0x103F,
|
||||
GFX1100 = 0x1100,
|
||||
GFX1101 = 0x1101,
|
||||
GFX1102 = 0x1102,
|
||||
GFX1103 = 0x1103,
|
||||
GFX1150 = 0x1150,
|
||||
GFX1151 = 0x1151,
|
||||
GFX1152 = 0x1152,
|
||||
GFX11_GENERIC = 0x11FF,
|
||||
GFX1200 = 0x1200,
|
||||
GFX1201 = 0x1201,
|
||||
GFX12_GENERIC = 0x12FF,
|
||||
HOST = 0x0000,
|
||||
};
|
||||
|
||||
enum struct amdgcn_target_family_id
|
||||
{
|
||||
GFX9 = 0x09,
|
||||
GFX10_3 = 0x10,
|
||||
GFX11 = 0x11,
|
||||
GFX12 = 0x12,
|
||||
HOST = 0x00,
|
||||
};
|
||||
|
||||
enum struct amdgcn_target_arch_id
|
||||
{
|
||||
CDNA = 0x01,
|
||||
RDNA = 0x02,
|
||||
HOST = 0x00,
|
||||
};
|
||||
|
||||
enum struct amdgcn_target_wave_size_id
|
||||
{
|
||||
WAVE32 = 32u,
|
||||
WAVE64 = 64u,
|
||||
HOST = 64u, // TODO: Is this correct? Should the host default to 64 or 1?
|
||||
};
|
||||
|
||||
#if 1 //__cplusplus <= 201703L
|
||||
|
||||
template <amdgcn_target_id TargetId = amdgcn_target_id::HOST,
|
||||
amdgcn_target_family_id FamilyId = amdgcn_target_family_id::HOST,
|
||||
amdgcn_target_arch_id ArchId = amdgcn_target_arch_id::HOST,
|
||||
amdgcn_target_wave_size_id WaveSizeId = amdgcn_target_wave_size_id::HOST>
|
||||
struct amdgcn_target
|
||||
{
|
||||
static constexpr amdgcn_target_id TARGET_ID = TargetId;
|
||||
static constexpr amdgcn_target_family_id FAMILY_ID = FamilyId;
|
||||
static constexpr amdgcn_target_arch_id ARCH_ID = ArchId;
|
||||
static constexpr amdgcn_target_wave_size_id WAVE_SIZE_ID = WaveSizeId;
|
||||
};
|
||||
|
||||
template <amdgcn_target_id targetId>
|
||||
static constexpr auto make_amdgcn_gfx9_target()
|
||||
{
|
||||
return amdgcn_target<targetId,
|
||||
amdgcn_target_family_id::GFX9,
|
||||
amdgcn_target_arch_id::CDNA,
|
||||
amdgcn_target_wave_size_id::WAVE64>{};
|
||||
}
|
||||
|
||||
template <amdgcn_target_id targetId>
|
||||
static constexpr auto make_amdgcn_gfx10_3_target()
|
||||
{
|
||||
return amdgcn_target<targetId,
|
||||
amdgcn_target_family_id::GFX10_3,
|
||||
amdgcn_target_arch_id::RDNA,
|
||||
amdgcn_target_wave_size_id::WAVE32>{};
|
||||
}
|
||||
|
||||
template <amdgcn_target_id targetId>
|
||||
static constexpr auto make_amdgcn_gfx11_target()
|
||||
{
|
||||
return amdgcn_target<targetId,
|
||||
amdgcn_target_family_id::GFX11,
|
||||
amdgcn_target_arch_id::RDNA,
|
||||
amdgcn_target_wave_size_id::WAVE32>{};
|
||||
}
|
||||
|
||||
template <amdgcn_target_id targetId>
|
||||
static constexpr auto make_amdgcn_gfx12_target()
|
||||
{
|
||||
return amdgcn_target<targetId,
|
||||
amdgcn_target_family_id::GFX12,
|
||||
amdgcn_target_arch_id::RDNA,
|
||||
amdgcn_target_wave_size_id::WAVE32>{};
|
||||
}
|
||||
|
||||
template <typename CompilerTarget, amdgcn_target_id... TargetIds>
|
||||
static constexpr auto is_target_id_any_of()
|
||||
{
|
||||
return is_any_value_of(CompilerTarget::TARGET_ID, TargetIds...);
|
||||
}
|
||||
|
||||
template <typename CompilerTarget, amdgcn_target_family_id... FamilyIds>
|
||||
static constexpr auto is_target_family_any_of()
|
||||
{
|
||||
return is_any_value_of(CompilerTarget::FAMILY_ID, FamilyIds...);
|
||||
}
|
||||
|
||||
template <typename CompilerTarget>
|
||||
static constexpr bool is_target_family_gfx9()
|
||||
{
|
||||
return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX9;
|
||||
}
|
||||
|
||||
template <typename CompilerTarget>
|
||||
static constexpr bool is_target_family_gfx10_3()
|
||||
{
|
||||
return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX10_3;
|
||||
}
|
||||
|
||||
template <typename CompilerTarget>
|
||||
static constexpr bool is_target_family_gfx11()
|
||||
{
|
||||
return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX11;
|
||||
}
|
||||
|
||||
template <typename CompilerTarget>
|
||||
static constexpr bool is_target_family_gfx12()
|
||||
{
|
||||
return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX12;
|
||||
}
|
||||
|
||||
template <typename CompilerTarget>
|
||||
static constexpr bool is_target_arch_cdna()
|
||||
{
|
||||
return CompilerTarget::ARCH_ID == amdgcn_target_arch_id::CDNA;
|
||||
}
|
||||
|
||||
template <typename CompilerTarget>
|
||||
static constexpr bool is_target_arch_rdna()
|
||||
{
|
||||
return CompilerTarget::ARCH_ID == amdgcn_target_arch_id::RDNA;
|
||||
}
|
||||
|
||||
template <typename CompilerTarget>
|
||||
static constexpr bool is_target_wave_size_32()
|
||||
{
|
||||
return CompilerTarget::WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE32;
|
||||
}
|
||||
|
||||
template <typename CompilerTarget>
|
||||
static constexpr bool is_target_wave_size_64()
|
||||
{
|
||||
return CompilerTarget::WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE64;
|
||||
}
|
||||
|
||||
// Helper to map compiler state to target arch id
|
||||
|
||||
#define MAP_COMPILER_STATE_TO_GFX9_TARGET(COMPILER_STATE, TARGET_ID) \
|
||||
if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
|
||||
{ \
|
||||
return make_amdgcn_gfx9_target<amdgcn_target_id::TARGET_ID>(); \
|
||||
} \
|
||||
else
|
||||
|
||||
#define MAP_COMPILER_STATE_TO_GFX10_3_TARGET(COMPILER_STATE, TARGET_ID) \
|
||||
if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
|
||||
{ \
|
||||
return make_amdgcn_gfx10_3_target<amdgcn_target_id::TARGET_ID>(); \
|
||||
} \
|
||||
else
|
||||
|
||||
#define MAP_COMPILER_STATE_TO_GFX11_TARGET(COMPILER_STATE, TARGET_ID) \
|
||||
if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
|
||||
{ \
|
||||
return make_amdgcn_gfx11_target<amdgcn_target_id::TARGET_ID>(); \
|
||||
} \
|
||||
else
|
||||
|
||||
#define MAP_COMPILER_STATE_TO_GFX12_TARGET(COMPILER_STATE, TARGET_ID) \
|
||||
if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
|
||||
{ \
|
||||
return make_amdgcn_gfx12_target<amdgcn_target_id::TARGET_ID>(); \
|
||||
} \
|
||||
else
|
||||
|
||||
/**
|
||||
* @brief Returns the amdgcn_target of the current compiler pass.
|
||||
* @note This is where we tie the compiler state to our internal target architecture representation
|
||||
* at compile time.
|
||||
*/
|
||||
constexpr auto get_compiler_target()
|
||||
{
|
||||
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX908, GFX908);
|
||||
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX90A, GFX90A);
|
||||
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX942, GFX942);
|
||||
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX950, GFX950);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1030, GFX1030);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1031, GFX1031);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1032, GFX1032);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1034, GFX1034);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1035, GFX1035);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1036, GFX1036);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX10_3_GENERIC, GFX103_GENERIC);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1100, GFX1100);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1101, GFX1101);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1102, GFX1102);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1103, GFX1103);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1150, GFX1150);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1151, GFX1151);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1152, GFX1152);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX11_GENERIC, GFX11_GENERIC);
|
||||
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1200, GFX1200);
|
||||
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1201, GFX1201);
|
||||
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX12_GENERIC, GFX12_GENERIC);
|
||||
|
||||
// Return HOST by default
|
||||
if constexpr(amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE)
|
||||
{
|
||||
return amdgcn_target<>{};
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
#undef MAP_COMPILER_STATE_TO_GFX9_TARGET
|
||||
#undef MAP_COMPILER_STATE_TO_GFX10_3_TARGET
|
||||
#undef MAP_COMPILER_STATE_TO_GFX11_TARGET
|
||||
#undef MAP_COMPILER_STATE_TO_GFX12_TARGET
|
||||
|
||||
// Sanity check: device compile must have a valid target architecture
|
||||
static_assert(!amdgcn_compiler_target_state::CK_TILE_DEVICE_COMPILE ||
|
||||
get_compiler_target().TARGET_ID != amdgcn_target_id::HOST,
|
||||
"Device compile must have a valid target device architecture");
|
||||
|
||||
// Sanity check: host compile must have HOST target architecture
|
||||
static_assert(!amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE ||
|
||||
get_compiler_target().TARGET_ID == amdgcn_target_id::HOST,
|
||||
"Host compile must target HOST architecture");
|
||||
|
||||
// TODO: c++20 use the make functions and constexpr if to avoid string construction and find at
|
||||
// runtime
|
||||
#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID(NAME_STRING, TARGET_ID) \
|
||||
if(str.find(NAME_STRING) != std::string::npos) \
|
||||
{ \
|
||||
return amdgcn_target_id::TARGET_ID; \
|
||||
} \
|
||||
else
|
||||
|
||||
/**
|
||||
* @brief Converts a lower-case string to the corresponding amdgcn_target_id value.
|
||||
* Returns amdgcn_target_id::HOST if no match is found.
|
||||
* Matches if the input contains the architecture substring.
|
||||
* Example: "gfx908", "gfx90a", "gfx1100", etc. can be parsed from hip runtime info.
|
||||
*/
|
||||
// TODO: c++20 constexpr if and string_view to avoid std::string construction and find at runtime
|
||||
// TODO: c++20 return amdgcn_target instance instead of just the target id
|
||||
// Maps an architecture name string to an amdgcn_target_id by substring search.
// The checks run in the order listed below and the first architecture substring
// found in the input wins. A null pointer or a string without any known
// architecture substring yields amdgcn_target_id::HOST.
CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target_id(char const* testStr)
{
    // Constructing a std::string from a null pointer is undefined behaviour,
    // so treat it the same as "no match".
    if(testStr == nullptr)
    {
        return amdgcn_target_id::HOST;
    }

    auto str = std::string(testStr);

    // GFX9 family
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx908", GFX908);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx90a", GFX90A);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx942", GFX942);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx950", GFX950);

    // GFX10.3 family
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1030", GFX1030);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1031", GFX1031);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1032", GFX1032);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1034", GFX1034);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1035", GFX1035);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1036", GFX1036);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx10_3_generic", GFX103_GENERIC);

    // GFX11 family
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1100", GFX1100);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1101", GFX1101);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1102", GFX1102);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1103", GFX1103);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1150", GFX1150);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1151", GFX1151);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1152", GFX1152);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx11_generic", GFX11_GENERIC);

    // GFX12 family
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1200", GFX1200);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1201", GFX1201);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx12_generic", GFX12_GENERIC);

    // Default case: return HOST target if no match is found
    return amdgcn_target_id::HOST;
}
|
||||
|
||||
#undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for a compiler target if the target id is in the list of supported target
|
||||
* ids
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
* @tparam SupportedTargetIds The list of supported target ids, e.g., amdgcn_target_id::GFX908
|
||||
*/
|
||||
template <typename CompilerTarget, amdgcn_target_id... SupportedTargetIds>
|
||||
using enable_if_target_id_t =
|
||||
std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID, SupportedTargetIds...)>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for a compiler target if the family id is in the list of supported family
|
||||
* ids
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
* @tparam SupportedTargetFamilyIds The list of supported family ids, e.g.,
|
||||
* amdgcn_target_family_id::GFX9
|
||||
*/
|
||||
template <typename CompilerTarget, amdgcn_target_family_id... SupportedTargetFamilyIds>
|
||||
using enable_if_target_family_id_t =
|
||||
std::enable_if_t<is_any_value_of(CompilerTarget::FAMILY_ID, SupportedTargetFamilyIds...)>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for a compiler target if the arch id is in the list of supported arch ids
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
* @tparam SupportedTargetArchIds The list of supported arch ids, e.g., amdgcn_target_arch_id::CDNA
|
||||
*/
|
||||
template <typename CompilerTarget, amdgcn_target_arch_id... SupportedTargetArchIds>
|
||||
using enable_if_target_arch_id_t =
|
||||
std::enable_if_t<is_any_value_of(CompilerTarget::ARCH_ID, SupportedTargetArchIds...)>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for a compiler target if the wave size id is in the list of supported wave
|
||||
* size ids
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
* @tparam SupportedTargetWaveSizeIds The list of supported wave size ids, e.g.,
|
||||
* amdgcn_target_wave_size_id::WAVE64
|
||||
*/
|
||||
template <typename CompilerTarget, amdgcn_target_wave_size_id... SupportedTargetWaveSizeIds>
|
||||
using enable_if_target_wave_size_id_t =
|
||||
std::enable_if_t<is_any_value_of(CompilerTarget::WAVE_SIZE_ID, SupportedTargetWaveSizeIds...)>;
|
||||
|
||||
/// Specialized enablers for common families, architectures, and wave sizes ///
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for GFX9 family targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <typename CompilerTarget>
|
||||
using enable_if_target_family_gfx9_t =
|
||||
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX9>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for GFX10.3 family targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <typename CompilerTarget>
|
||||
using enable_if_target_family_gfx10_3_t =
|
||||
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX10_3>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for GFX11 family targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <typename CompilerTarget>
|
||||
using enable_if_target_family_gfx11_t =
|
||||
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX11>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for GFX12 family targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <typename CompilerTarget>
|
||||
using enable_if_target_family_gfx12_t =
|
||||
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX12>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for CDNA architecture targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <typename CompilerTarget>
|
||||
using enable_if_target_arch_cdna_t =
|
||||
enable_if_target_arch_id_t<CompilerTarget, amdgcn_target_arch_id::CDNA>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for RDNA architecture targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <typename CompilerTarget>
|
||||
using enable_if_target_arch_rdna_t =
|
||||
enable_if_target_arch_id_t<CompilerTarget, amdgcn_target_arch_id::RDNA>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for WAVE32 size targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <typename CompilerTarget>
|
||||
using enable_if_target_wave32_t =
|
||||
enable_if_target_wave_size_id_t<CompilerTarget, amdgcn_target_wave_size_id::WAVE32>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for WAVE64 size targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <typename CompilerTarget>
|
||||
using enable_if_target_wave64_t =
|
||||
enable_if_target_wave_size_id_t<CompilerTarget, amdgcn_target_wave_size_id::WAVE64>;
|
||||
|
||||
#elif __cplusplus >= 202002L
|
||||
|
||||
struct amdgcn_target
|
||||
{
|
||||
// Target architecture identifiers
|
||||
// These are set to HOST (0) by default
|
||||
// TARGET_ID is the specific architecture id (e.g., GFX908)
|
||||
// FAMILY_ID is the architecture family id (e.g., GFX9)
|
||||
// ARCH_ID is the architecture class id (e.g., CDNA, RDNA)
|
||||
// WAVE_SIZE_ID is the wavefront size id (e.g., WAVE32, WAVE64)
|
||||
const amdgcn_target_id TARGET_ID = amdgcn_target_id::HOST;
|
||||
const amdgcn_target_family_id FAMILY_ID = amdgcn_target_family_id::HOST;
|
||||
const amdgcn_target_arch_id ARCH_ID = amdgcn_target_arch_id::HOST;
|
||||
const amdgcn_target_wave_size_id WAVE_SIZE_ID = amdgcn_target_wave_size_id::HOST;
|
||||
};
|
||||
|
||||
static constexpr auto make_amdgcn_gfx10_3_target(amdgcn_target_id targetId)
|
||||
{
|
||||
return amdgcn_target{.TARGET_ID = targetId,
|
||||
.FAMILY_ID = amdgcn_target_family_id::GFX10_3,
|
||||
.ARCH_ID = amdgcn_target_arch_id::RDNA,
|
||||
.WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32};
|
||||
}
|
||||
|
||||
static constexpr auto make_amdgcn_gfx9_target(amdgcn_target_id targetId)
|
||||
{
|
||||
return amdgcn_target{.TARGET_ID = targetId,
|
||||
.FAMILY_ID = amdgcn_target_family_id::GFX9,
|
||||
.ARCH_ID = amdgcn_target_arch_id::CDNA,
|
||||
.WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE64};
|
||||
}
|
||||
|
||||
static constexpr auto make_amdgcn_gfx11_target(amdgcn_target_id targetId)
|
||||
{
|
||||
return amdgcn_target{.TARGET_ID = targetId,
|
||||
.FAMILY_ID = amdgcn_target_family_id::GFX11,
|
||||
.ARCH_ID = amdgcn_target_arch_id::RDNA,
|
||||
.WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32};
|
||||
}
|
||||
|
||||
static constexpr auto make_amdgcn_gfx12_target(amdgcn_target_id targetId)
|
||||
{
|
||||
return amdgcn_target{.TARGET_ID = targetId,
|
||||
.FAMILY_ID = amdgcn_target_family_id::GFX12,
|
||||
.ARCH_ID = amdgcn_target_arch_id::RDNA,
|
||||
.WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32};
|
||||
}
|
||||
|
||||
static constexpr bool is_target_family_gfx9(amdgcn_target target)
|
||||
{
|
||||
return target.FAMILY_ID == amdgcn_target_family_id::GFX9;
|
||||
}
|
||||
|
||||
static constexpr bool is_target_family_gfx10_3(amdgcn_target target)
|
||||
{
|
||||
return target.FAMILY_ID == amdgcn_target_family_id::GFX10_3;
|
||||
}
|
||||
|
||||
static constexpr bool is_target_family_gfx11(amdgcn_target target)
|
||||
{
|
||||
return target.FAMILY_ID == amdgcn_target_family_id::GFX11;
|
||||
}
|
||||
|
||||
static constexpr bool is_target_family_gfx12(amdgcn_target target)
|
||||
{
|
||||
return target.FAMILY_ID == amdgcn_target_family_id::GFX12;
|
||||
}
|
||||
|
||||
static constexpr bool is_target_arch_cdna(amdgcn_target target)
|
||||
{
|
||||
return target.ARCH_ID == amdgcn_target_arch_id::CDNA;
|
||||
}
|
||||
|
||||
static constexpr bool is_target_arch_rdna(amdgcn_target target)
|
||||
{
|
||||
return target.ARCH_ID == amdgcn_target_arch_id::RDNA;
|
||||
}
|
||||
|
||||
static constexpr bool is_target_wave_size_32(amdgcn_target target)
|
||||
{
|
||||
return target.WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE32;
|
||||
}
|
||||
|
||||
static constexpr bool is_target_wave_size_64(amdgcn_target target)
|
||||
{
|
||||
return target.WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE64;
|
||||
}
|
||||
|
||||
// Helper to map compiler state to target arch id
|
||||
// Maps a compiler-state flag to a GFX10.3 family target: when
// amdgcn_compiler_target_state::COMPILER_STATE is true, the enclosing function
// returns the result of make_amdgcn_gfx10_3_target for the given target id.
// (This previously called make_amdgcn_gfx9_target by mistake.)
#define MAP_COMPILER_STATE_TO_GFX10_3_TARGET(COMPILER_STATE, TARGET_ID) \
    if constexpr(amdgcn_compiler_target_state::COMPILER_STATE)          \
    {                                                                   \
        return make_amdgcn_gfx10_3_target(amdgcn_target_id::TARGET_ID); \
    }
|
||||
|
||||
#define MAP_COMPILER_STATE_TO_GFX9_TARGET(COMPILER_STATE, TARGET_ID) \
|
||||
if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
|
||||
{ \
|
||||
return make_amdgcn_gfx9_target(amdgcn_target_id::TARGET_ID); \
|
||||
}
|
||||
|
||||
#define MAP_COMPILER_STATE_TO_GFX11_TARGET(COMPILER_STATE, TARGET_ID) \
|
||||
if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
|
||||
{ \
|
||||
return make_amdgcn_gfx11_target(amdgcn_target_id::TARGET_ID); \
|
||||
}
|
||||
|
||||
#define MAP_COMPILER_STATE_TO_GFX12_TARGET(COMPILER_STATE, TARGET_ID) \
|
||||
if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
|
||||
{ \
|
||||
return make_amdgcn_gfx12_target(amdgcn_target_id::TARGET_ID); \
|
||||
}
|
||||
|
||||
/*! @brief Returns the amdgcn_target of the current compiler pass.
|
||||
* @note This is where we tie the compiler state to our internal target architecture representation
|
||||
* at compile time.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_compiler_target()
|
||||
{
|
||||
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX908, GFX908);
|
||||
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX90A, GFX90A);
|
||||
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX942, GFX942);
|
||||
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX950, GFX950);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1030, GFX1030);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1031, GFX1031);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1032, GFX1032);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1034, GFX1034);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1035, GFX1035);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1036, GFX1036);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX10_3_GENERIC, GFX103_GENERIC);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1100, GFX1100);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1101, GFX1101);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1102, GFX1102);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1103, GFX1103);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1150, GFX1150);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1151, GFX1151);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1152, GFX1152);
|
||||
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX11_GENERIC, GFX11_GENERIC);
|
||||
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1200, GFX1200);
|
||||
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1201, GFX1201);
|
||||
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX12_GENERIC, GFX12_GENERIC);
|
||||
|
||||
// Default to HOST
|
||||
return amdgcn_target{};
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
#undef MAP_COMPILER_STATE_TO_GFX9_TARGET
|
||||
#undef MAP_COMPILER_STATE_TO_GFX10_3_TARGET
|
||||
#undef MAP_COMPILER_STATE_TO_GFX11_TARGET
|
||||
#undef MAP_COMPILER_STATE_TO_GFX12_TARGET
|
||||
|
||||
// Sanity check: device compile must have a valid target architecture
|
||||
static_assert(!amdgcn_compiler_target_state::CK_TILE_DEVICE_COMPILE ||
|
||||
get_compiler_target().TARGET_ID != amdgcn_target_id::HOST,
|
||||
"Device compile must have a valid target device architecture");
|
||||
|
||||
// Sanity check: host compile must have HOST target architecture
|
||||
static_assert(!amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE ||
|
||||
get_compiler_target().TARGET_ID == amdgcn_target_id::HOST,
|
||||
"Host compile must target HOST architecture");
|
||||
|
||||
// Returns a GFX9 family target from the enclosing function when `str` contains
// NAME_STRING. Expects a std::string named `str` in the expansion scope. The
// search result is only known at run time, so this is a plain `if`;
// `if constexpr` requires a constant expression and would not compile here.
#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET(NAME_STRING, TARGET_ID) \
    if(str.find(NAME_STRING) != std::string::npos)                                      \
    {                                                                                   \
        return make_amdgcn_gfx9_target(amdgcn_target_id::TARGET_ID);                    \
    }                                                                                   \
    else
|
||||
|
||||
// Returns a GFX10.3 family target from the enclosing function when `str`
// contains NAME_STRING. Expects a std::string named `str` in the expansion
// scope. The search result is only known at run time, so this is a plain `if`;
// `if constexpr` requires a constant expression and would not compile here.
#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET(NAME_STRING, TARGET_ID) \
    if(str.find(NAME_STRING) != std::string::npos)                                         \
    {                                                                                      \
        return make_amdgcn_gfx10_3_target(amdgcn_target_id::TARGET_ID);                    \
    }                                                                                      \
    else
|
||||
|
||||
// Returns a GFX11 family target from the enclosing function when `str`
// contains NAME_STRING. Expects a std::string named `str` in the expansion
// scope. The search result is only known at run time, so this is a plain `if`;
// `if constexpr` requires a constant expression and would not compile here.
#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET(NAME_STRING, TARGET_ID) \
    if(str.find(NAME_STRING) != std::string::npos)                                       \
    {                                                                                    \
        return make_amdgcn_gfx11_target(amdgcn_target_id::TARGET_ID);                    \
    }                                                                                    \
    else
|
||||
|
||||
// Returns a GFX12 family target from the enclosing function when `str`
// contains NAME_STRING. Expects a std::string named `str` in the expansion
// scope. The search result is only known at run time, so this is a plain `if`;
// `if constexpr` requires a constant expression and would not compile here.
#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET(NAME_STRING, TARGET_ID) \
    if(str.find(NAME_STRING) != std::string::npos)                                       \
    {                                                                                    \
        return make_amdgcn_gfx12_target(amdgcn_target_id::TARGET_ID);                    \
    }                                                                                    \
    else
|
||||
|
||||
/**
|
||||
* @brief Converts a lower-case string to the corresponding amdgcn_target instance.
|
||||
* Returns a default-constructed (HOST) amdgcn_target if no match is found.
|
||||
* Matches if the input contains the architecture substring.
|
||||
* Example: "gfx908", "gfx90a", "gfx1100", etc. can be parsed from hip runtime info.
|
||||
*/
|
||||
CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target(char const* testStr)
|
||||
{
|
||||
auto str = std::string(testStr);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx908", GFX908);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx90a", GFX90A);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx942", GFX942);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx950", GFX950);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1030", GFX1030);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1031", GFX1031);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1032", GFX1032);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1034", GFX1034);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1035", GFX1035);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1036", GFX1036);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx10_3_generic", GFX103_GENERIC);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1100", GFX1100);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1101", GFX1101);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1102", GFX1102);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1103", GFX1103);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1150", GFX1150);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1151", GFX1151);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1152", GFX1152);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx11_generic", GFX11_GENERIC);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx1200", GFX1200);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx1201", GFX1201);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx12_generic", GFX12_GENERIC);
|
||||
|
||||
// Default case
|
||||
return amdgcn_target{};
|
||||
}
|
||||
|
||||
#undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET
|
||||
#undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET
|
||||
#undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET
|
||||
#undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for a compiler target if the target id is in the list of supported target
|
||||
* ids
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
* @tparam SupportedTargetIds The list of supported target ids, e.g., amdgcn_target_id::GFX908
|
||||
*/
|
||||
template <amdgcn_target CompilerTarget, amdgcn_target_id... SupportedTargetIds>
|
||||
using enable_if_target_id_t =
|
||||
std::enable_if_t<is_any_value_of(CompilerTarget.TARGET_ID, SupportedTargetIds...)>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for a compiler target if the family id is in the list of supported family
|
||||
* ids
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
* @tparam SupportedTargetFamilyIds The list of supported family ids, e.g.,
|
||||
* amdgcn_target_family_id::GFX9
|
||||
*/
|
||||
template <amdgcn_target CompilerTarget, amdgcn_target_family_id... SupportedTargetFamilyIds>
|
||||
using enable_if_target_family_id_t =
|
||||
std::enable_if_t<is_any_value_of(CompilerTarget.FAMILY_ID, SupportedTargetFamilyIds...)>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for a compiler target if the arch id is in the list of supported arch ids
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
* @tparam SupportedTargetArchIds The list of supported arch ids, e.g., amdgcn_target_arch_id::CDNA
|
||||
*/
|
||||
template <amdgcn_target CompilerTarget, amdgcn_target_arch_id... SupportedTargetArchIds>
|
||||
using enable_if_target_arch_id_t =
|
||||
std::enable_if_t<is_any_value_of(CompilerTarget.ARCH_ID, SupportedTargetArchIds...)>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for a compiler target if the wave size id is in the list of supported wave
|
||||
* size ids
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
* @tparam SupportedTargetWaveSizeIds The list of supported wave size ids, e.g.,
|
||||
* amdgcn_target_wave_size_id::WAVE64
|
||||
*/
|
||||
template <amdgcn_target CompilerTarget, amdgcn_target_wave_size_id... SupportedTargetWaveSizeIds>
|
||||
using enable_if_target_wave_size_id_t =
|
||||
std::enable_if_t<is_any_value_of(CompilerTarget.WAVE_SIZE_ID, SupportedTargetWaveSizeIds...)>;
|
||||
|
||||
/// Specialized enablers for common families, architectures, and wave sizes ///
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for GFX9 family targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <amdgcn_target CompilerTarget>
|
||||
using enable_if_target_family_gfx9_t =
|
||||
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX9>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for GFX10.3 family targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <amdgcn_target CompilerTarget>
|
||||
using enable_if_target_family_gfx10_3_t =
|
||||
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX10_3>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for GFX11 family targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <amdgcn_target CompilerTarget>
|
||||
using enable_if_target_family_gfx11_t =
|
||||
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX11>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for GFX12 family targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <amdgcn_target CompilerTarget>
|
||||
using enable_if_target_family_gfx12_t =
|
||||
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX12>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for CDNA architecture targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <amdgcn_target CompilerTarget>
|
||||
using enable_if_target_arch_cdna_t =
|
||||
enable_if_target_arch_id_t<CompilerTarget, amdgcn_target_arch_id::CDNA>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for RDNA architecture targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <amdgcn_target CompilerTarget>
|
||||
using enable_if_target_arch_rdna_t =
|
||||
enable_if_target_arch_id_t<CompilerTarget, amdgcn_target_arch_id::RDNA>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for WAVE32 size targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <amdgcn_target CompilerTarget>
|
||||
using enable_if_target_wave32_t =
|
||||
enable_if_target_wave_size_id_t<CompilerTarget, amdgcn_target_wave_size_id::WAVE32>;
|
||||
|
||||
/**
|
||||
* @brief SFINAE enabler for WAVE64 size targets
|
||||
* @tparam CompilerTarget The compiler target to check
|
||||
*/
|
||||
template <amdgcn_target CompilerTarget>
|
||||
using enable_if_target_wave64_t =
|
||||
enable_if_target_wave_size_id_t<CompilerTarget, amdgcn_target_wave_size_id::WAVE64>;
|
||||
|
||||
#endif // __cplusplus <= 201703L
|
||||
|
||||
} // namespace core::arch
|
||||
|
||||
CK_TILE_HOST bool is_wave32()
|
||||
{
|
||||
hipDeviceProp_t props{};
|
||||
@@ -86,6 +825,13 @@ CK_TILE_HOST bool is_wave32()
|
||||
return props.major > 9;
|
||||
}
|
||||
|
||||
/*! @brief Returns the amdgcn_wave_size of the current compiler pass
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
|
||||
{
|
||||
return static_cast<index_t>(core::arch::get_compiler_target().WAVE_SIZE_ID);
|
||||
}
|
||||
|
||||
/*! @brief Returns gridDim.x (the x-extent of the launch grid) as an index_t.
 */
CK_TILE_DEVICE index_t get_grid_size()
{
    const index_t grid_extent_x = gridDim.x;
    return grid_extent_x;
}
|
||||
|
||||
/*! @brief Returns blockDim.x (the x-extent of the thread block) as an index_t.
 */
CK_TILE_DEVICE index_t get_block_size()
{
    const index_t block_extent_x = blockDim.x;
    return block_extent_x;
}
|
||||
|
||||
118
include/ck_tile/core/arch/mma/amdgcn_mma.hpp
Normal file
118
include/ck_tile/core/arch/mma/amdgcn_mma.hpp
Normal file
@@ -0,0 +1,118 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
 * @struct Unsupported
 * @brief Meta-tag to indicate unsupported amdgcn_mma instance.
 *        Only declared, never defined: it is used purely as a type-level marker (OpType).
 */
struct Unsupported;
|
||||
|
||||
#if defined(__cpp_concepts) && __cpp_concepts >= 201907L
/**
 * @concept MmaOpI
 * @brief Expresses the meta-data interface required for each MmaOp policy:
 *        an OpType tag, the A/B/C vector types, the CK layout constants, and a static
 *        exec(A, B, C) function whose result is convertible to CVecType.
 */
template <typename MmaOp>
concept MmaOpI = requires(MmaOp op) {
    // Requires an op context
    typename MmaOp::OpType;

    // Captures types for inputs / outputs to mma function
    typename MmaOp::AVecType;
    typename MmaOp::BVecType;
    typename MmaOp::CVecType;

    // Captures CK-specific layout properties
    { MmaOp::kAMBlock } -> std::convertible_to<unsigned int>;
    { MmaOp::kBNBlock } -> std::convertible_to<unsigned int>;
    { MmaOp::kAMLane } -> std::convertible_to<unsigned int>;
    { MmaOp::kBNLane } -> std::convertible_to<unsigned int>;
    { MmaOp::kABKLane } -> std::convertible_to<unsigned int>;
    { MmaOp::kABKPerLane } -> std::convertible_to<unsigned int>;
    { MmaOp::kCMLane } -> std::convertible_to<unsigned int>;
    { MmaOp::kCNLane } -> std::convertible_to<unsigned int>;
    { MmaOp::kCM0PerLane } -> std::convertible_to<unsigned int>;
    { MmaOp::kCM1PerLane } -> std::convertible_to<unsigned int>;

    // Static exec function
    {
        MmaOp::exec(
            typename MmaOp::AVecType{}, typename MmaOp::BVecType{}, typename MmaOp::CVecType{})
    } -> std::convertible_to<typename MmaOp::CVecType>;
};

#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L
|
||||
|
||||
/**
|
||||
* @class amdgcn_mma
|
||||
* @brief This is the default MmaOp policy.
|
||||
* Instances of this class are to be used as MmaOp policies.
|
||||
* Light builtin wrapper for mfma / wmma instructions. This class's job is to
|
||||
* provide a uniform interface to invoke the appropriate instruction
|
||||
* based on the template parameters provided. This interface is to bridge
|
||||
* the gap between the ck_tile API types and the native __builtin types.
|
||||
* @tparam ADataType Datatype of input A
|
||||
* @tparam BDataType Datatype of input B
|
||||
* @tparam CDataType Datatype of accumulator
|
||||
* @tparam BlockM M-dimension of mma block
|
||||
* @tparam BlockN N-dimension of mma block
|
||||
* @tparam BlockK K-dimension of mma block
|
||||
* @tparam CtrlFlags Control flags for mma operation
|
||||
* @tparam CompilerTarget The current compiler target
|
||||
* @tparam Enabler SFINAE enabler
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockK,
|
||||
typename CtrlFlags,
|
||||
typename CompilerTarget,
|
||||
typename Enabler = void>
|
||||
struct amdgcn_mma
|
||||
{
|
||||
// The base instance is unsupported because there is no __builtin to wrap.
|
||||
using OpType = Unsupported;
|
||||
|
||||
// Interface types for A, B, C vectors types
|
||||
using AVecType = ext_vector_t<ADataType, 1>;
|
||||
using BVecType = ext_vector_t<BDataType, 1>;
|
||||
using CVecType = ext_vector_t<CDataType, 1>;
|
||||
|
||||
// Layout constants - default to 0
|
||||
static constexpr index_t kAMBlock = 0;
|
||||
static constexpr index_t kBNBlock = 0;
|
||||
|
||||
static constexpr index_t kAMLane = 0;
|
||||
static constexpr index_t kBNLane = 0;
|
||||
static constexpr index_t kABKLane = 0;
|
||||
static constexpr index_t kABKPerLane = 0;
|
||||
|
||||
static constexpr index_t kCMLane = 0;
|
||||
static constexpr index_t kCNLane = 0;
|
||||
static constexpr index_t kCM0PerLane = 0;
|
||||
static constexpr index_t kCM1PerLane = 0;
|
||||
|
||||
// This is a default pass-through implementation that doesn't do anything practical.
|
||||
CK_TILE_DEVICE static CVecType const&
|
||||
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
|
||||
{
|
||||
ignore(regsA, regsB);
|
||||
return regsC; // No-op, just return C
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
// Include the implementations
|
||||
#include "wmma/wmma.hpp"
|
||||
#include "mfma/mfma.hpp"
|
||||
10
include/ck_tile/core/arch/mma/mfma/mfma.hpp
Normal file
10
include/ck_tile/core/arch/mma/mfma/mfma.hpp
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
// Include the architecture-specific MFMA implementations and traits
|
||||
#include "mfma_gfx9.hpp"
|
||||
#include "mfma_traits.hpp"
|
||||
#include "mfma_selector.hpp"
|
||||
#include "mfma_transforms.hpp"
|
||||
162
include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp
Normal file
162
include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp
Normal file
@@ -0,0 +1,162 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mfma_traits.hpp"
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed.
|
||||
// This is because some built-ins are only available on certain target ids.
|
||||
// We can also do things such add some padding specializations for when we need to use
|
||||
// smaller values of K that aren't directly supported by the built-ins.
|
||||
// For flexibility, it is recommended that for each backend wrapper it supports at least
|
||||
// one packed register for each input to be able to process smaller K values by padding.
|
||||
|
||||
/**
 * @struct DefaultMfmaCtrlFlags
 * @brief Default MFMA flags, no broadcasting or rotation of inputs
 */
struct DefaultMfmaCtrlFlags
{
    static constexpr uint32_t Cbsz = 0; // CBSZ flag, default 0
    static constexpr uint32_t Abid = 0; // ABID flag, default 0
    static constexpr uint32_t Blgp = 0; // BLGP flag, default 0
};
|
||||
|
||||
#if defined(__cpp_concepts) && __cpp_concepts >= 201907L

/**
 * @concept CtrlFlagsGfx9I
 * @brief Expresses the interface of required members for each CtrlFlags type on Gfx9:
 *        static members Cbsz, Abid and Blgp, each convertible to int.
 */
template <typename CtrlFlags>
concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) {
    // Flag members for Gfx9 MFMA instructions
    { CtrlFlags::Cbsz } -> std::convertible_to<int>;
    { CtrlFlags::Abid } -> std::convertible_to<int>;
    { CtrlFlags::Blgp } -> std::convertible_to<int>;
};

#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for MFMA on GFX9 targets
|
||||
*
|
||||
* This specialization implements the MFMA instruction for fp16_t A and B
|
||||
* matrices, and fp32_t accumulator matrix, with 16x16x16 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
{
|
||||
// Mfma operation type
|
||||
using OpType = MfmaOp;
|
||||
|
||||
// Register types
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
using BVecType = ext_vector_t<fp16_t, 4>;
|
||||
using CVecType = ext_vector_t<fp32_t, 4>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 4;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_f32_16x16x16f16(aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
static_cast<int>(CtrlFlags::Cbsz),
|
||||
static_cast<int>(CtrlFlags::Abid),
|
||||
static_cast<int>(CtrlFlags::Blgp))};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the MFMA instruction for fp16_t A and B
|
||||
* matrices, and fp32_t accumulator matrix, with 16x16x32 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
{
|
||||
using OpType = MfmaOp;
|
||||
|
||||
// Packed register types
|
||||
using AVecType = ext_vector_t<fp16_t, 8>;
|
||||
using BVecType = ext_vector_t<fp16_t, 8>;
|
||||
using CVecType = ext_vector_t<fp32_t, 4>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 8;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_f32_16x16x32_f16(aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
static_cast<int>(CtrlFlags::Cbsz),
|
||||
static_cast<int>(CtrlFlags::Abid),
|
||||
static_cast<int>(CtrlFlags::Blgp))};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
189
include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp
Normal file
189
include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp
Normal file
@@ -0,0 +1,189 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
#include "mfma_traits.hpp"
|
||||
#include "mfma_gfx9.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class MfmaDefaultSelector
|
||||
* @brief Implements a default MFMA selector strategy for gfx9 target architectures.
|
||||
* This implements the K dimension search strategy to find the largest supported MFMA
|
||||
* instruction for the given M/N block sizes and datatypes.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through
|
||||
implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Block M dimension size
|
||||
* @tparam BlockN Block N dimension size
|
||||
* @tparam BlockKTest Current Block K dimension size to test
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @note Here we assume that BlockKTest is always a power-of-two integer.
|
||||
* The search strategy starts from a maximum BlockKTest size down to 1u by halving
|
||||
* each time.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockKTest,
|
||||
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest))
|
||||
struct MfmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
// Define our candidate MFMA implementation for the current parameters
|
||||
using CandidateOp =
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
|
||||
CompilerTarget>;
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u
|
||||
// and fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
|
||||
CandidateOp,
|
||||
typename MfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest / 2u,
|
||||
CompilerTarget>::SelectedOp>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MfmaDefaultSelector
|
||||
* @brief Implements the base case for the default MFMA selector when no supported instruction is
|
||||
* found.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Block M dimension size
|
||||
* @tparam BlockN Block N dimension size
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
struct MfmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CompilerTarget>
|
||||
{
|
||||
// Default unsupported pass-through if no instruction is found
|
||||
using SelectedOp =
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
1u,
|
||||
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
|
||||
CompilerTarget>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the gfx9 default MMA selector strategy for wave-wise MMA decomposition.
|
||||
* This implements the M/N block size search strategy to find the largest supported MFMA
|
||||
* instruction for the given datatypes.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Size of the M dimension of the fragment to decompose
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
{
|
||||
private:
|
||||
// Provide the default depth-K search strategy for each class of common MFMA shapes.
|
||||
// Start searching from the largest K dimension MFMA shape down to the smallest.
|
||||
using CandidateOp4x4 =
|
||||
typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 4u, 4u, 4u, CompilerTarget>::
|
||||
SelectedOp;
|
||||
using CandidateOp16x16 = typename MfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
16u,
|
||||
16u,
|
||||
128u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
using CandidateOp32x32 = typename MfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
32u,
|
||||
32u,
|
||||
64u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Default operation triggers pass-through
|
||||
using DefaultOp =
|
||||
typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
|
||||
SelectedOp;
|
||||
|
||||
// Traits for each candidate
|
||||
using CandidateTraits4x4 = MmaOpTraits<CandidateOp4x4>;
|
||||
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
|
||||
using CandidateTraits32x32 = MmaOpTraits<CandidateOp32x32>;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the MFMA shape
|
||||
static constexpr bool IsSupported4x4 =
|
||||
CandidateTraits4x4::IsSupported && (FragM % CandidateTraits4x4::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits4x4::BlockN == 0u) && (FragK % CandidateTraits4x4::BlockK == 0u);
|
||||
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
|
||||
(FragM % CandidateTraits16x16::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits16x16::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits16x16::BlockK == 0u);
|
||||
static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
|
||||
(FragM % CandidateTraits32x32::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits32x32::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits32x32::BlockK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported MFMA operation for the given fragment shape
|
||||
using SelectedOp = std::conditional_t<
|
||||
IsSupported32x32,
|
||||
CandidateOp32x32,
|
||||
std::conditional_t<IsSupported16x16,
|
||||
CandidateOp16x16,
|
||||
std::conditional_t<IsSupported4x4, CandidateOp4x4, DefaultOp>>>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
44
include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp
Normal file
44
include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
 * @struct MfmaOp
 * @brief Meta-tag for the MFMA operation. This will be used in the MmaOp policies to
 * identify the operation as an MFMA operation.
 */
struct MfmaOp;

/**
 * @class is_mma_op_mfma
 * @brief Trait to check if MmaOp is an MFMA operation.
 *        Primary template: false for any type that does not match the specialization below
 *        (including types without a nested OpType).
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
template <typename MmaOp, typename = void>
struct is_mma_op_mfma : std::false_type
{
};

/**
 * @struct is_mma_op_mfma
 * @brief MmaOp specialization for MFMA operations, confirming the OpType matches MfmaOp
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
template <typename MmaOp>
// TODO: c++20 requires
struct is_mma_op_mfma<MmaOp, std::enable_if_t<std::is_same_v<typename MmaOp::OpType, MfmaOp>>>
    : std::true_type
{
};

/**
 * @brief Convenience evaluator for is_mma_op_mfma trait
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
template <typename MmaOp>
inline constexpr bool is_mma_op_mfma_v = is_mma_op_mfma<MmaOp>::value;
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
38
include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp
Normal file
38
include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultTransformsGfx9
|
||||
* @brief Implements the default MMA transforms for gfx9 targets
|
||||
*/
|
||||
struct MmaDefaultTransformsGfx9
|
||||
{
|
||||
using ATransform = PassThroughTransform;
|
||||
using BTransform = PassThroughTransform;
|
||||
using CTransform = PassThroughTransform;
|
||||
using DTransform = PassThroughTransform;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaTransformsDefaultSelector
|
||||
* @brief Implements the default MMA transforms selection for gfx9 targets
|
||||
* @tparam MmaOp Mma operation
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsGfx9;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
234
include/ck_tile/core/arch/mma/mma.hpp
Normal file
234
include/ck_tile/core/arch/mma/mma.hpp
Normal file
@@ -0,0 +1,234 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
#include "amdgcn_mma.hpp"
|
||||
#include "mma_selector.hpp"
|
||||
#include "mma_traits.hpp"
|
||||
#include "mma_transforms.hpp"
|
||||
|
||||
#include "mfma/mfma.hpp"
|
||||
#include "wmma/wmma.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/*! @enum MmaAccumPolicy
 * @brief Accumulation order for Mma decomposition
 */
enum struct MmaAccumPolicy
{
    // Decomposition and accumulation in row-major block order
    ROW_MAJOR,
    // Decomposition and accumulation in col-major block order
    COL_MAJOR
};
|
||||
|
||||
/**
|
||||
* @class Mma
|
||||
* @brief Driver for the wave-tile Mma operation. Given a backend block-wise MmaOp implementation
|
||||
* (e.g., mfma or wmma), this class performs block-wise decomposition to matrix-multiply input
|
||||
* fragments of (A: FragM x FragK) x (B: FragK x FragN) and accumulates results into output fragment
|
||||
* (C: FragM x FragN).
|
||||
* @tparam ADataType Data type of input fragment A
|
||||
* @tparam BDataType Data type of input fragment B
|
||||
* @tparam CDataType Data type of input/output fragment C (accumulator)
|
||||
* @tparam FragM Mma fragment M dimension
|
||||
* @tparam FragN Mma fragment N dimension
|
||||
* @tparam FragK Mma fragment K dimension
|
||||
* @tparam AccumPolicy The block order of the accumulation registers (row major or col major block
|
||||
* order)
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam MmaOp The backend wrapper class that will perform block-wise mma op (e.g., mfma or
|
||||
* wmma)
|
||||
* @tparam MmaTransforms The set of transforms to be applied to input/output fragments
|
||||
* @par This is an example of an Mma decomposition driver class that can be used in a wave-tile
|
||||
* context. Given a fragment size, we can decompose the fragment into smaller block-wise mma ops
|
||||
* that are natively supported by the hardware (e.g., mfma or wmma). The class also supports
|
||||
* applying transforms to the input/output fragments as needed (e.g., layout conversions, data type
|
||||
* conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the
|
||||
* output fragment. This is a powerful example of how to build a flexible and reusable mma driver
|
||||
* that can adapt to different hardware capabilities and requirements.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
|
||||
typename CompilerTarget =
|
||||
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
|
||||
// get_compiler_target(),
|
||||
typename MmaOp =
|
||||
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
|
||||
// MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget>::SelectedOp,
|
||||
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
|
||||
typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
|
||||
struct WaveWiseMma
|
||||
{
|
||||
|
||||
using BlockWiseMmaOp = MmaOp;
|
||||
using BlockWiseMmaOpTraits = MmaOpTraits<BlockWiseMmaOp>;
|
||||
|
||||
// Block dimensions
|
||||
constexpr static uint32_t BlockM = BlockWiseMmaOpTraits::BlockM;
|
||||
constexpr static uint32_t BlockN = BlockWiseMmaOpTraits::BlockN;
|
||||
constexpr static uint32_t BlockK = BlockWiseMmaOpTraits::BlockK;
|
||||
|
||||
// Block counts for decomposition
|
||||
constexpr static uint32_t BlocksM = FragM / BlockM;
|
||||
constexpr static uint32_t BlocksN = FragN / BlockN;
|
||||
constexpr static uint32_t BlocksK = FragK / BlockK;
|
||||
constexpr static uint32_t BlocksC = BlocksM * BlocksN;
|
||||
|
||||
// Vector types for packed registers in each block
|
||||
using AVecType = typename BlockWiseMmaOpTraits::AVecType;
|
||||
using BVecType = typename BlockWiseMmaOpTraits::BVecType;
|
||||
using CVecType = typename BlockWiseMmaOpTraits::CVecType;
|
||||
|
||||
// Buffer types for fragments
|
||||
using ABufferType = AVecType[BlocksM][BlocksK];
|
||||
using BBufferType = BVecType[BlocksN][BlocksK];
|
||||
using CBufferType = CVecType[BlocksM][BlocksN];
|
||||
|
||||
// Transforms
|
||||
using ATransform = typename MmaTransforms::ATransform;
|
||||
using BTransform = typename MmaTransforms::BTransform;
|
||||
using CTransform = typename MmaTransforms::CTransform;
|
||||
using DTransform = typename MmaTransforms::DTransform;
|
||||
|
||||
// Sanity checks
|
||||
static_assert(FragM >= BlockM, "FragM must be larger than BlockM");
|
||||
static_assert(FragN >= BlockN, "FragN must be larger than BlockN");
|
||||
static_assert(FragK >= BlockK, "FragK must be larger than BlockK");
|
||||
static_assert(FragM % BlockM == 0u, "FragM must be a multiple of BlockM");
|
||||
static_assert(FragN % BlockN == 0u, "FragN must be a multiple of BlockN");
|
||||
static_assert(FragK % BlockK == 0u, "FragK must be a multiple of BlockK");
|
||||
|
||||
private:
|
||||
template <typename DstT, typename SrcT>
|
||||
CK_TILE_DEVICE static auto formatBuffer(SrcT const& inputBuffer)
|
||||
{
|
||||
// TODO: Implement formatting logic as needed.
|
||||
// This is intended to convert input fragments to the native vector types
|
||||
// required by the BlockWiseMma operation for iteration
|
||||
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
|
||||
return reinterpret_cast<DstT const&>(inputBuffer);
|
||||
}
|
||||
|
||||
template <typename DstT, typename SrcT>
|
||||
CK_TILE_DEVICE static auto formatBuffer(SrcT& inputBuffer)
|
||||
{
|
||||
// TODO: Implement formatting logic as needed.
|
||||
// This is intended to convert input fragments to the native vector types
|
||||
// required by the BlockWiseMma operation for iteration
|
||||
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
|
||||
return reinterpret_cast<DstT&>(inputBuffer);
|
||||
}
|
||||
|
||||
/*! @brief Execute Mma in row-major accumulation order.
|
||||
* @tparam VecTA The input fragment A vector type
|
||||
* @tparam VecTB The input fragment B vector type
|
||||
* @tparam VecTC The input/output fragment C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
{
|
||||
// We implement an example wave-tile pipeline here.
|
||||
// First, we apply the necessary transforms to the input fragments,
|
||||
// then we convert the result into buffers of native vector formats
|
||||
// that we can easily index. Native vector formats are necessary inputs
|
||||
// to the given MmaOp exec function.
|
||||
auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
|
||||
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
|
||||
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
|
||||
|
||||
// "Col-major" accumulation over the M-dimension blocks first.
|
||||
// Pseudo code here, but we would basically iterate over the blocks in col-major order
|
||||
for(uint32_t bn = 0u; bn < BlocksN; ++bn)
|
||||
{
|
||||
for(uint32_t bm = 0u; bm < BlocksM; ++bm)
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < BlocksK; ++bk)
|
||||
{
|
||||
c_frag[bm][bn] =
|
||||
BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert native vector results back to the output fragment format
|
||||
// and then return after we apply the final output transform.
|
||||
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
|
||||
}
|
||||
|
||||
/*! @brief Execute Mma in row-major accumulation order.
|
||||
* @tparam VecTA The input fragment A vector type
|
||||
* @tparam VecTB The input fragment B vector type
|
||||
* @tparam VecTC The input/output fragment C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
{
|
||||
// We implement an example wave-tile pipeline here.
|
||||
// First, we apply the necessary transforms to the input fragments,
|
||||
// then we convert the result into buffers of native vector formats
|
||||
// that we can easily index. Native vector formats are necessary inputs
|
||||
// to the given MmaOp exec function.
|
||||
auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
|
||||
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
|
||||
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
|
||||
|
||||
// "Row-major" accumulation over the N-dimension blocks first.
|
||||
// Pseudo code here, but we would basically iterate over the blocks in row-major order.
|
||||
// We also have to ensure that the incoming vector fragments are converted to native vector
|
||||
// types before passing to the BlockWiseMma exec function.
|
||||
for(uint32_t bm = 0u; bm < BlocksM; ++bm)
|
||||
{
|
||||
for(uint32_t bn = 0u; bn < BlocksN; ++bn)
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < BlocksK; ++bk)
|
||||
{
|
||||
c_frag[bm][bn] =
|
||||
BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert native vector results back to the output fragment format
|
||||
// and then return after we apply the final output transform.
|
||||
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
|
||||
}
|
||||
|
||||
public:
|
||||
/*! @brief Forward to Mma operation with specified accumulation order.
|
||||
* @tparam VecTA The input fragment A vector type
|
||||
* @tparam VecTB The input fragment B vector type
|
||||
* @tparam VecTC The input/output fragment C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
{
|
||||
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
|
||||
{
|
||||
return exec_row_major(
|
||||
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
|
||||
}
|
||||
else // if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
|
||||
{
|
||||
return exec_col_major(
|
||||
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
63
include/ck_tile/core/arch/mma/mma_selector.hpp
Normal file
63
include/ck_tile/core/arch/mma/mma_selector.hpp
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class MmaDefaultSelector
|
||||
* @brief Implements a default mma selector strategy for the current target architecture.
|
||||
* This is simply intended as a default selection strategy for mma instruction operations.
|
||||
* Given the particular datatypes and Fragment dimensions, the selector will attempt to
|
||||
* select the instruction with the largest K dimension that is supported on the current target
|
||||
* architecture.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Fragment M dimension
|
||||
* @tparam FragN Fragment N dimension
|
||||
* @tparam FragK Fragment K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam Enable SFINAE enabler
|
||||
* @note Here we distinguish that Fragment MNK sizes from Block MNK sizes used in the actual MMA
|
||||
* operation. Fragment sizes correspond to the overall tile size being computed, while Block sizes
|
||||
* correspond to the size of the individual MMA instructions being used to compute the overall in
|
||||
* block-wise. The Fragment sizes must be multiples of the Block sizes and in general larger than or
|
||||
* equal to the Block sizes.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget,
|
||||
typename Enable = void>
|
||||
// TODO c++20 requires
|
||||
struct MmaDefaultSelector
|
||||
{
|
||||
// By default, no selection is made, and we fall back to a pass-through unsupported
|
||||
// implementation. This is because we do not have any knowledge of the target architecture here.
|
||||
using SelectedOp =
|
||||
amdgcn_mma<ADataType, BDataType, CDataType, FragM, FragN, FragK, void, amdgcn_target<>>;
|
||||
};
|
||||
|
||||
#if defined(__cpp_concepts) && __cpp_concepts >= 201907L

/**
 * @concept MmaSelectorI
 * @brief Interface every MmaSelector class must satisfy.
 * The only requirement is a nested `SelectedOp` type naming the chosen op.
 */
template <typename MmaSelector>
concept MmaSelectorI = requires {
    // The outcome of the selection
    typename MmaSelector::SelectedOp;
};

#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
// Include the implementations
|
||||
#include "wmma/wmma_selector.hpp"
|
||||
#include "mfma/mfma_selector.hpp"
|
||||
151
include/ck_tile/core/arch/mma/mma_traits.hpp
Normal file
151
include/ck_tile/core/arch/mma/mma_traits.hpp
Normal file
@@ -0,0 +1,151 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "amdgcn_mma.hpp"
|
||||
#include "mfma/mfma_traits.hpp"
|
||||
#include "wmma/wmma_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class is_mma_op_supported
|
||||
* @brief Trait to check if MmaOp is supported
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp, typename = void>
|
||||
template <typename MmaOp, typename = void>
|
||||
struct is_mma_op_supported : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct is_mma_op_supported
|
||||
* @brief The MmaOp is unsupported specialization
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp>
|
||||
template <typename MmaOp>
|
||||
// TODO: c++20 requires
|
||||
struct is_mma_op_supported<MmaOp,
|
||||
std::enable_if_t<std::is_same_v<typename MmaOp::OpType, Unsupported>>>
|
||||
: std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Convenience evaluation of is_mma_op_supported
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp>
|
||||
template <typename MmaOp>
|
||||
static constexpr bool is_mma_op_supported_v = is_mma_op_supported<MmaOp>::value;
|
||||
|
||||
/**
|
||||
* @class MmaOpParams
|
||||
* @brief Reflects the template parameters of a given MmaOp
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp>
|
||||
template <typename MmaOp>
|
||||
struct MmaOpParams;
|
||||
|
||||
#if defined(__cpp_concepts) && __cpp_concepts >= 201907L

/**
 * @concept MmaOpParamsI
 * @brief Expresses the members every MmaOpParams reflection type must expose.
 * Required nested types: ADataType, BDataType, CDataType, CtrlFlags and
 * CompilerTarget. Required static values: BlockM, BlockN and BlockK, each
 * convertible to unsigned int.
 * @note The compiler target is required as the nested `CompilerTarget` type,
 * which is what the MmaOpParams specialization for amdgcn_mma provides. A
 * `GfxTargetId` value member exists there only as a commented-out TODO, so
 * requiring it would make this concept unsatisfiable by that specialization.
 */
template <typename MmaOpParams>
concept MmaOpParamsI = requires(MmaOpParams op) {
    // Captured type parameters
    typename MmaOpParams::ADataType;
    typename MmaOpParams::BDataType;
    typename MmaOpParams::CDataType;
    typename MmaOpParams::CtrlFlags;
    typename MmaOpParams::CompilerTarget;

    // Captured block dimensions
    { MmaOpParams::BlockM } -> std::convertible_to<unsigned int>;
    { MmaOpParams::BlockN } -> std::convertible_to<unsigned int>;
    { MmaOpParams::BlockK } -> std::convertible_to<unsigned int>;
};

#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L
|
||||
|
||||
/**
|
||||
* @struct MmaOpParams
|
||||
* @brief Reflects the template parameters of a given MmaOp
|
||||
* @tparam ADataType_ Data type of matrix A
|
||||
* @tparam BDataType_ Data type of matrix B
|
||||
* @tparam CDataType_ Data type of the accumulator
|
||||
* @tparam BlockM_ Size of the M dimension
|
||||
* @tparam BlockN_ Size of the N dimension
|
||||
* @tparam BlockK_ Size of the K dimension
|
||||
* @tparam CtrlFlags_ Control flags for the MMA operation
|
||||
* @tparam CompilerTarget_ The compiler target
|
||||
*/
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
uint32_t BlockM_,
|
||||
uint32_t BlockN_,
|
||||
uint32_t BlockK_,
|
||||
typename CtrlFlags_,
|
||||
typename CompilerTarget_>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget_>
|
||||
struct MmaOpParams<amdgcn_mma<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockM_,
|
||||
BlockN_,
|
||||
BlockK_,
|
||||
CtrlFlags_,
|
||||
CompilerTarget_>>
|
||||
{
|
||||
// Capture incoming template parameters
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using CDataType = CDataType_;
|
||||
static constexpr uint32_t BlockM = BlockM_;
|
||||
static constexpr uint32_t BlockN = BlockN_;
|
||||
static constexpr uint32_t BlockK = BlockK_;
|
||||
using CtrlFlags = CtrlFlags_;
|
||||
using CompilerTarget = CompilerTarget_;
|
||||
// TODO c++20static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @class MmaOpTraits
|
||||
* @brief Reflects the template parameters and static members of a given MmaOp.
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
// TODO: c++20 template <MmaOpI MmaOp>
|
||||
// TODO: c++20 requires MmaOpParamsI<MmaOpParams<MmaOp>>
|
||||
struct MmaOpTraits : public MmaOpParams<MmaOp>
|
||||
{
|
||||
// Capture internal MmaOp static members
|
||||
using OpType = typename MmaOp::OpType;
|
||||
using AVecType = typename MmaOp::AVecType;
|
||||
using BVecType = typename MmaOp::BVecType;
|
||||
using CVecType = typename MmaOp::CVecType;
|
||||
|
||||
// Capture layout parameters
|
||||
static constexpr index_t kAMBlock = MmaOp::kAMBlock;
|
||||
static constexpr index_t kBNBlock = MmaOp::kBNBlock;
|
||||
static constexpr index_t kAMLane = MmaOp::kAMLane;
|
||||
static constexpr index_t kBNLane = MmaOp::kBNLane;
|
||||
static constexpr index_t kABKLane = MmaOp::kABKLane;
|
||||
static constexpr index_t kABKPerLane = MmaOp::kABKPerLane;
|
||||
static constexpr index_t kCMLane = MmaOp::kCMLane;
|
||||
static constexpr index_t kCNLane = MmaOp::kCNLane;
|
||||
static constexpr index_t kCM0PerLane = MmaOp::kCM0PerLane;
|
||||
static constexpr index_t kCM1PerLane = MmaOp::kCM1PerLane;
|
||||
|
||||
// Additional traits to identify the type of MmaOp at compile time
|
||||
constexpr static bool IsMfma = is_mma_op_mfma_v<MmaOp>;
|
||||
constexpr static bool IsWmma = is_mma_op_wmma_v<MmaOp>;
|
||||
constexpr static bool IsSupported = is_mma_op_supported_v<MmaOp>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
48
include/ck_tile/core/arch/mma/mma_transforms.hpp
Normal file
48
include/ck_tile/core/arch/mma/mma_transforms.hpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct PassThroughTransform
|
||||
* @brief A no-op transform that passes through the input as-is.
|
||||
*/
|
||||
struct PassThroughTransform
|
||||
{
|
||||
template <typename VecType>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
|
||||
{
|
||||
return std::forward<VecType>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @class MmaTransformsDefaultSelector
|
||||
* @brief Default selector for MmaTransforms based on MmaOp and CompilerTarget
|
||||
* @tparam MmaOp The Mma operation type
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam Enable SFINAE parameter for specialization
|
||||
*/
|
||||
template <typename MmaOp, typename CompilerTarget, typename Enable = void>
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget, typename Enable = void>
|
||||
struct MmaTransformsDefaultSelector;
|
||||
|
||||
#if defined(__cpp_concepts) && __cpp_concepts >= 201907L

/**
 * @concept MmaTransformsI
 * @brief Interface every MmaTransforms type must satisfy.
 * A conforming type names one transform per operand through the nested types
 * ATransform, BTransform, CTransform and DTransform.
 */
template <typename MmaTransforms>
concept MmaTransformsI = requires {
    typename MmaTransforms::ATransform;
    typename MmaTransforms::BTransform;
    typename MmaTransforms::CTransform;
    typename MmaTransforms::DTransform;
};

#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
34
include/ck_tile/core/arch/mma/wmma/wmma.hpp
Normal file
34
include/ck_tile/core/arch/mma/wmma/wmma.hpp
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
 * @enum WmmaCtrlFlags
 * @brief Common wmma control flags for gfx11 and gfx12
 * Two independent on/off settings share this single bool-backed enum, so
 * LOW and UNSIGNED both equal false while HIGH and SIGNED both equal true.
 */
enum class WmmaCtrlFlags : bool
{
    // Accumulator half selection.
    // Only has an effect on gfx11 when the accumulator is 16-bit: picks which
    // half of the 32-bit accum register is used.
    //   LOW  = bits [15:0]
    //   HIGH = bits [31:16]
    LOW  = false,
    HIGH = true,

    // Signedness indicator of inputs / accum.
    // Only has an effect on gfx11 / gfx12 when the input is 8-bit int.
    UNSIGNED = false,
    SIGNED   = true
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
// Include the architecture-specific WMMA implementations and traits
|
||||
#include "wmma_gfx11.hpp"
|
||||
#include "wmma_gfx12.hpp"
|
||||
#include "wmma_selector.hpp"
|
||||
#include "wmma_traits.hpp"
|
||||
#include "wmma_transforms.hpp"
|
||||
109
include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp
Normal file
109
include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp
Normal file
@@ -0,0 +1,109 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "wmma_traits.hpp"
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
// TODO: Specifically for gfx11 wmma, we need to deal with quirks such as:
|
||||
// - Duplicating A and B inputs
|
||||
// - Handling C / D is always in b32, even for f16 accumulation.
|
||||
// NOTE: Two suggestions:
|
||||
// 1) We could do it here in the wrappers by accepting packed inputs, then swizzling them to
|
||||
// duplicate the inputs as needed before calling the actual built-in. This may introduce
|
||||
// some instruction overhead and violate single responsibility clauses, but keeps the logic
|
||||
// contained within the backend wrapper.
|
||||
// 2) We could do it at a higher level, e.g. in the Mma interface (workflow) by introducing
|
||||
// pre-mma, mma and post-mma steps. The pre-mma step could handle input duplication transform
|
||||
// post-mma could implement D-shuffle transform. This may be cleaner and more flexible than
|
||||
// trying to handle everything in the backend wrappers.
|
||||
//
|
||||
// This current example assumes duplication has already been done, and that C data shuffles have
|
||||
// already been completed. (e.g. option 2 above). These expect duplicated inputs and pre-shuffled
|
||||
// data in C.
|
||||
|
||||
// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed.
|
||||
// This is because some built-ins are only available on certain target ids.
|
||||
// We can also do things such as adding padding specializations for when we need to use
|
||||
// smaller values of K that aren't directly supported by the built-ins.
|
||||
// For flexibility, it is recommended that for each backend wrapper it supports at least
|
||||
// one packed register for each input to be able to process smaller K values by padding.
|
||||
|
||||
/**
|
||||
* @class DefaultWmmaFlags
|
||||
* @brief Generates default WMMA control flags based on data types.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
*/
|
||||
template <typename ADataType, typename BDataType, typename CDataType>
|
||||
struct DefaultWmmaCtrlFlags
|
||||
{
|
||||
// Generate default flags for signage
|
||||
// Only used currently for integer inputs / accum in gfx11 / gfx12
|
||||
constexpr static WmmaCtrlFlags InputSignA =
|
||||
std::is_signed_v<ADataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
|
||||
constexpr static WmmaCtrlFlags InputSignB =
|
||||
std::is_signed_v<BDataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
|
||||
constexpr static WmmaCtrlFlags AccumSign =
|
||||
std::is_signed_v<CDataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
|
||||
|
||||
// Generate default flags for accumulator destination bits.
|
||||
// Only used if accumulation size is 16-bit in gfx11
|
||||
constexpr static WmmaCtrlFlags AccumBits = WmmaCtrlFlags::LOW;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX11
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx11_t<CompilerTarget>>
|
||||
{
|
||||
// Wmma operation type
|
||||
using OpType = WmmaOp;
|
||||
|
||||
// Register types (duplicated input / b32 accum)
|
||||
using AVecType = ext_vector_t<fp16_t, 16>;
|
||||
using BVecType = ext_vector_t<fp16_t, 16>;
|
||||
using CVecType = ext_vector_t<fp32_t, 8>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 8;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 2;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 1;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aVec, bVec, cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
69
include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp
Normal file
69
include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "wmma_traits.hpp"
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed.
|
||||
// This is because some built-ins are only available on certain target ids.
|
||||
// We can also do things such as adding padding specializations for when we need to use
|
||||
// smaller values of K that aren't directly supported by the built-ins.
|
||||
// For flexibility, it is recommended that for each backend wrapper it supports at least
|
||||
// one packed register for each input to be able to process smaller K values by padding.
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_wmma for fp16_t, fp16_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
{
|
||||
// Wmma operation type
|
||||
using OpType = WmmaOp;
|
||||
|
||||
// Register types
|
||||
using AVecType = ext_vector_t<fp16_t, 8>;
|
||||
using BVecType = ext_vector_t<fp16_t, 8>;
|
||||
using CVecType = ext_vector_t<fp32_t, 8>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 8;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 2;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 1;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(aVec, bVec, cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
161
include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp
Normal file
161
include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp
Normal file
@@ -0,0 +1,161 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class WmmaDefaultSelector
|
||||
* @brief Implements a default WMMA selector strategy for gfx11/12 target architectures.
|
||||
* This implements the K dimension search strategy to find the largest supported WMMA
|
||||
* instruction for the given M/N block sizes and datatypes.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Size of the M dimension
|
||||
* @tparam BlockN Size of the N dimension
|
||||
* @tparam BlockKTest Size of the K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockKTest,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_rdna_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest))
|
||||
struct WmmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
// By default, let's assume no special flags for WMMA
|
||||
using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
|
||||
|
||||
// Define our candidate WMMA implementation for the current parameters
|
||||
using CandidateOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
CtrlFlags,
|
||||
CompilerTarget>;
|
||||
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u
|
||||
// and fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
|
||||
CandidateOp,
|
||||
typename WmmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest / 2u,
|
||||
CompilerTarget>::SelectedOp>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct WmmaDefaultSelector
|
||||
* @brief Implements a default WMMA selector strategy for gfx11/12 target architectures.
|
||||
* This implements the K dimension == 1, which is the base case for the recursive K dimension
|
||||
* search. If no supported instruction is found, falls back to an unsupported pass-through
|
||||
* implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Size of the M dimension
|
||||
* @tparam BlockN Size of the N dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id GfxTargetId>
|
||||
struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CompilerTarget>
|
||||
{
|
||||
// By default, let's assume no special flags for WMMA
|
||||
using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
|
||||
|
||||
// Default unsupported pass-through if no instruction is found
|
||||
using SelectedOp =
|
||||
amdgcn_mma<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CtrlFlags, CompilerTarget>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the rdna default MMA selector strategy for wave-wise MMA decomposition.
|
||||
* This implements the M/N block size search strategy to find the largest supported WMMA
|
||||
* instruction for the given datatypes.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Size of the M dimension of the fragment to decompose
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
enable_if_target_arch_rdna_t<CompilerTarget>>
|
||||
{
|
||||
private:
|
||||
// Provide the default depth-K search strategy for each class of common WMMA shapes.
|
||||
// Start searching from the largest K dimension MFMA shape down to the smallest.
|
||||
using CandidateOp16x16 = typename WmmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
16u,
|
||||
16u,
|
||||
128u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Default operation triggers pass-through
|
||||
using DefaultOp =
|
||||
typename WmmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
|
||||
SelectedOp;
|
||||
|
||||
// Traits for each candidate
|
||||
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the WMMA shape
|
||||
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
|
||||
(FragM % CandidateTraits16x16::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits16x16::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits16x16::BlockK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported WMMA operation for the given fragment shape
|
||||
using SelectedOp = std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
44
include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp
Normal file
44
include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once

#include <type_traits>
|
||||
|
||||
namespace ck_tile::core::arch::mma {

/**
 * @struct WmmaOp
 * @brief Meta-tag for the WMMA operation. It is used as the OpType of an MmaOp struct to
 * identify the operation as a WMMA operation.
 * @note Only forward-declared in this header: it serves as a type-level tag and is compared by
 * type identity, so no definition is required here.
 */
struct WmmaOp;

/**
 * @struct is_mma_op_wmma
 * @brief Trait to check if MmaOp is a WMMA operation.
 *
 * Primary template: evaluates to false. It is selected whenever the specialization below is
 * discarded, i.e. when MmaOp has no nested OpType or its OpType is not WmmaOp.
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
template <typename MmaOp, typename = void>
struct is_mma_op_wmma : std::false_type
{
};

/**
 * @struct is_mma_op_wmma
 * @brief MmaOp specialization for WMMA operations, confirming the OpType matches WmmaOp.
 *
 * The enable_if condition is only well-formed (and true) when MmaOp::OpType exists and is exactly
 * WmmaOp; otherwise substitution fails and the primary template is used.
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
template <typename MmaOp>
// TODO: c++20 requires
struct is_mma_op_wmma<MmaOp, std::enable_if_t<std::is_same_v<typename MmaOp::OpType, WmmaOp>>>
    : std::true_type
{
};

/**
 * @brief Convenience evaluator for is_mma_op_wmma trait
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 * @note Declared `inline` rather than `static` so that all translation units including this
 * header share a single entity instead of each getting an internal-linkage copy.
 */
template <typename MmaOp>
inline constexpr bool is_mma_op_wmma_v = is_mma_op_wmma<MmaOp>::value;

} // namespace ck_tile::core::arch::mma
|
||||
112
include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp
Normal file
112
include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct DuplicateTransform
|
||||
* @brief Placeholder for duplicating low register elements to high ones; currently a pass-through
|
||||
*/
|
||||
struct DuplicateTransform
|
||||
{
|
||||
template <typename VecType>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
|
||||
{
|
||||
// TODO: Implement duplication logic to broadcast low
|
||||
// register elements to high elements [0 - (N/2 -1)] -> [N/2 - (N-1)]
|
||||
return std::forward<VecType>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct PadTransform
|
||||
* @brief Placeholder for padding data from original type to b32 type; currently a pass-through
|
||||
*/
|
||||
struct PadTransform
|
||||
{
|
||||
template <typename VecType>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
|
||||
{
|
||||
// TODO: Implement b32 padding logic.
|
||||
// E.g., for fp16, pad each 16-bit element with 16 bits of 0 to make 32-bit elements
|
||||
return std::forward<VecType>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct UnpadTransform
|
||||
* @brief Placeholder for unpadding data from b32 type to original type; currently a pass-through
|
||||
*/
|
||||
struct UnpadTransform
|
||||
{
|
||||
template <typename VecType>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
|
||||
{
|
||||
// TODO: Implement b32 logic to unpad 32 to original data type.
|
||||
return std::forward<VecType>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultTransformsGfx11
|
||||
* @brief Default MMA transforms for GFX11 architecture
|
||||
*/
|
||||
struct MmaDefaultTransformsGfx11
|
||||
{
|
||||
using ATransform = DuplicateTransform;
|
||||
using BTransform = DuplicateTransform;
|
||||
using CTransform = PadTransform;
|
||||
using DTransform = UnpadTransform;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultTransformsGfx12
|
||||
* @brief Default MMA transforms for GFX12 architecture
|
||||
*/
|
||||
struct MmaDefaultTransformsGfx12
|
||||
{
|
||||
using ATransform = PassThroughTransform;
|
||||
using BTransform = PassThroughTransform;
|
||||
using CTransform = PassThroughTransform;
|
||||
using DTransform = PassThroughTransform;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaTransformsDefaultSelector
|
||||
* @brief Implements the default MMA transforms selection for gfx11 targets
|
||||
* @tparam MmaOp Mma operation
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
|
||||
// TODO: c++20 requires
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx11_t<CompilerTarget>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsGfx11;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaTransformsDefaultSelector
|
||||
* @brief Implements the default MMA transforms selection for gfx12 targets
|
||||
* @tparam MmaOp Mma operation
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
|
||||
// TODO: c++20 requires
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsGfx12;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
@@ -284,3 +284,242 @@
|
||||
#ifndef CK_TILE_ENC_SUPPORT_Y_TO_R
|
||||
#define CK_TILE_ENC_SUPPORT_Y_TO_R 0
|
||||
#endif
|
||||
|
||||
// Mark unsupported features with a deprecation warning in debug builds
|
||||
#if defined(NDEBUG)
|
||||
#define CK_TILE_UNSUPPORTED_IMPL(MSG)
|
||||
#else
|
||||
#define CK_TILE_UNSUPPORTED_IMPL(MSG) __attribute__((deprecated(MSG)))
|
||||
#endif
|
||||
|
||||
namespace ck_tile::core {
|
||||
/**
|
||||
* @struct amdgcn_compiler_target_state
|
||||
* @brief Defines compiler states for supported AMDGCN devices.
|
||||
* @var CK_TILE_HOST_COMPILE Indicates if the compilation is for the host.
|
||||
* @var CK_TILE_DEVICE_COMPILE Indicates if the compilation is for AMDGCN device.
|
||||
* @var CK_TILE_ARCH_GFX908 Indicates if the compiler target architecture is GFX908.
|
||||
* @var CK_TILE_ARCH_GFX90A Indicates if the compiler target architecture is GFX90A.
|
||||
* @var CK_TILE_ARCH_GFX942 Indicates if the compiler target architecture is GFX942.
|
||||
* @var CK_TILE_ARCH_GFX950 Indicates if the compiler target architecture is GFX950.
|
||||
* @var CK_TILE_ARCH_GFX1030 Indicates if the compiler target architecture is GFX1030.
|
||||
* @var CK_TILE_ARCH_GFX1031 Indicates if the compiler target architecture is GFX1031.
|
||||
* @var CK_TILE_ARCH_GFX1032 Indicates if the compiler target architecture is GFX1032.
|
||||
* @var CK_TILE_ARCH_GFX1034 Indicates if the compiler target architecture is GFX1034.
|
||||
* @var CK_TILE_ARCH_GFX1035 Indicates if the compiler target architecture is GFX1035.
|
||||
* @var CK_TILE_ARCH_GFX1036 Indicates if the compiler target architecture is GFX1036.
|
||||
* @var CK_TILE_ARCH_GFX10_3_GENERIC Indicates if the compiler target architecture is GFX10.3
|
||||
* generic.
|
||||
* @var CK_TILE_ARCH_GFX1100 Indicates if the compiler target architecture is GFX1100.
|
||||
* @var CK_TILE_ARCH_GFX1101 Indicates if the compiler target architecture is GFX1101.
|
||||
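* @var CK_TILE_ARCH_GFX1102 Indicates if the compiler target architecture is GFX1102.
|
||||
* @var CK_TILE_ARCH_GFX1103 Indicates if the compiler target architecture is GFX1103.
|
||||
* @var CK_TILE_ARCH_GFX1150 Indicates if the compiler target architecture is GFX1150.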
|
||||
* @var CK_TILE_ARCH_GFX1151 Indicates if the compiler target architecture is GFX1151.
|
||||
* @var CK_TILE_ARCH_GFX1152 Indicates if the compiler target architecture is GFX1152.
|
||||
* @var CK_TILE_ARCH_GFX11_GENERIC Indicates if the compiler target architecture is GFX11 generic.
|
||||
* @var CK_TILE_ARCH_GFX1200 Indicates if the compiler target architecture is GFX1200.
|
||||
* @var CK_TILE_ARCH_GFX1201 Indicates if the compiler target architecture is GFX1201.
|
||||
* @var CK_TILE_ARCH_GFX12_GENERIC Indicates if the compiler target architecture is GFX12 generic.
|
||||
*/
|
||||
struct amdgcn_compiler_target_state
|
||||
{
|
||||
// Determine if we are compiling for device or host
|
||||
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
|
||||
static constexpr bool CK_TILE_DEVICE_COMPILE = true;
|
||||
static constexpr bool CK_TILE_HOST_COMPILE = false;
|
||||
#else
|
||||
static constexpr bool CK_TILE_DEVICE_COMPILE = false;
|
||||
static constexpr bool CK_TILE_HOST_COMPILE = true;
|
||||
#endif // __HIP_DEVICE_COMPILE__ && __HIP_DEVICE_COMPILE__
|
||||
|
||||
// GFX9
|
||||
#if defined(__gfx908__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX908 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX908 = false;
|
||||
#endif // __gfx908__
|
||||
|
||||
#if defined(__gfx90a__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX90A = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX90A = false;
|
||||
#endif // __gfx90a__
|
||||
|
||||
#if defined(__gfx942__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX942 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX942 = false;
|
||||
#endif // __gfx942__
|
||||
|
||||
#if defined(__gfx950__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX950 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX950 = false;
|
||||
#endif // __gfx950__
|
||||
|
||||
// GFX10
|
||||
#if defined(__gfx1030__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1030 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1030 = false;
|
||||
#endif // __gfx1030__
|
||||
|
||||
#if defined(__gfx1031__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1031 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1031 = false;
|
||||
#endif // __gfx1031__
|
||||
|
||||
#if defined(__gfx1032__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1032 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1032 = false;
|
||||
#endif // __gfx1032__
|
||||
|
||||
#if defined(__gfx1034__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1034 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1034 = false;
|
||||
#endif // __gfx1034__
|
||||
|
||||
#if defined(__gfx1035__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1035 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1035 = false;
|
||||
#endif // __gfx1035__
|
||||
|
||||
#if defined(__gfx1036__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1036 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1036 = false;
|
||||
#endif // __gfx1036__
|
||||
|
||||
#if defined(__gfx10_3_generic__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX10_3_GENERIC = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX10_3_GENERIC = false;
|
||||
#endif // __gfx10_3_generic__
|
||||
|
||||
// GFX11
|
||||
#if defined(__gfx1100__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1100 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1100 = false;
|
||||
#endif // __gfx1100__
|
||||
|
||||
#if defined(__gfx1101__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1101 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1101 = false;
|
||||
#endif // __gfx1101__
|
||||
|
||||
#if defined(__gfx1102__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1102 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1102 = false;
|
||||
#endif // __gfx1102__
|
||||
|
||||
#if defined(__gfx1103__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1103 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1103 = false;
|
||||
#endif // __gfx1103__
|
||||
|
||||
#if defined(__gfx1150__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1150 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1150 = false;
|
||||
#endif // __gfx1150__
|
||||
|
||||
#if defined(__gfx1151__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1151 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1151 = false;
|
||||
#endif // __gfx1151__
|
||||
|
||||
#if defined(__gfx1152__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1152 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1152 = false;
|
||||
#endif // __gfx1152__
|
||||
|
||||
#if defined(__gfx11_generic__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX11_GENERIC = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX11_GENERIC = false;
|
||||
#endif // __gfx11_generic__
|
||||
|
||||
// GFX12
|
||||
#if defined(__gfx1200__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1200 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1200 = false;
|
||||
#endif // __gfx1200__
|
||||
|
||||
#if defined(__gfx1201__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1201 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1201 = false;
|
||||
#endif // __gfx1201__
|
||||
|
||||
#if defined(__gfx12_generic__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX12_GENERIC = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX12_GENERIC = false;
|
||||
#endif // __gfx12_generic__
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Helper to count the number of times an item is contained within a list of values
|
||||
* @tparam T The type of the search value
|
||||
* @tparam Ts The types of the search list values
|
||||
* @param search The value to search for
|
||||
* @param searchList The list of values to search in
|
||||
* @return The number of values in the search list that compare equal to the search value
|
||||
*/
|
||||
template <typename T, typename... Ts>
|
||||
// TODO: c++20 concept requires((std::is_convertible<Ts, T>::value && ...) && (sizeof...(Ts) >=
|
||||
// 1))
|
||||
CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... searchList)
|
||||
{
|
||||
static_assert((std::is_convertible<Ts, T>::value && ...),
|
||||
"All search list values must be convertible to the search value type");
|
||||
static_assert(sizeof...(Ts) >= 1, "At least one value must be provided to search in");
|
||||
|
||||
return (static_cast<uint32_t>(search == static_cast<T>(searchList)) + ...);
|
||||
}
|
||||
|
||||
#define CK_TILE_COMPILER_TARGETS_LIST \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX908, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX90A, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1034, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1035, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1036, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX10_3_GENERIC, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1100, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1101, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1102, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1103, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1150, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1151, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1152, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX11_GENERIC, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1200, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1201, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX12_GENERIC
|
||||
|
||||
// Sanity check: make sure exactly one target architecture is defined during device compile
|
||||
static_assert(!amdgcn_compiler_target_state::CK_TILE_DEVICE_COMPILE ||
|
||||
count_values_of(true, CK_TILE_COMPILER_TARGETS_LIST) == 1u,
|
||||
"Only one target architecture can be defined during device compile");
|
||||
|
||||
// Sanity check: make sure no device target architecture is defined during host compile
|
||||
static_assert(!amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE ||
|
||||
count_values_of(true, CK_TILE_COMPILER_TARGETS_LIST) == 0u,
|
||||
"No device target architecture can be defined during host compile");
|
||||
|
||||
} // namespace ck_tile::core
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -14,6 +14,10 @@ struct ignore_t
|
||||
constexpr void operator=(T&&) const noexcept
|
||||
{
|
||||
}
|
||||
template <typename... T>
|
||||
constexpr void operator()(T&&...) const noexcept
|
||||
{
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -128,6 +128,25 @@ struct is_any_of<CompareTo, FirstType, Rest...>
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Helper to check if a value is in a list of values
|
||||
* @tparam T The type of the search value
|
||||
* @tparam Ts The types of the search list values
|
||||
* @param search The value to search for
|
||||
* @param searchList The list of values to search in
|
||||
* @return true if the search value is in the search list, false otherwise
|
||||
*/
|
||||
template <typename T, typename... Ts>
|
||||
// TODO: c++20 requires((std::is_convertible<Ts, T>::value && ...) && (sizeof...(Ts) >= 1))
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_any_value_of(T search, Ts... searchList)
|
||||
{
|
||||
static_assert((std::is_convertible<Ts, T>::value && ...),
|
||||
"All searchList values must be convertible to the type of search");
|
||||
static_assert(sizeof...(Ts) >= 1, "searchList must contain at least one value");
|
||||
|
||||
return ((search == static_cast<T>(searchList)) || ...);
|
||||
}
|
||||
|
||||
// Helper to check if a type is a specialization of a given template
|
||||
template <typename Test, template <typename...> class RefTemplate>
|
||||
struct is_specialization_of : std::false_type
|
||||
|
||||
Reference in New Issue
Block a user