First look at mfma / wmma unification (#2704)

* First look at mfma / wmma unification

* Refactor

* Re-org file structure

* Restructure transform selection and WaveWiseMma class

* Update license files. Add missing gfx1151 support. Change wave size for HOST to 1. Update datatypes naming consistency

* Fixes default MmaSelector implementation

* Adds unit tests for amdgcn_mma and arch

* Consolidate common arch id checks to constexpr functions. Strongly type ids as amdgcn_target_arch_id object.

* Refactor is_any_value_of

* Fixes mma_selector logic

* Fix typo

* Add mma selector test for tile decomposition

* Fix compilation of mma.hpp

* Revert back to c++17 compatibility

* Fix compiler error by returning index_t from get_warp_size()

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Fixes compiler error for missing is_wave32() function

* Fixes compiler error for host wave_size() should be 64

* Fixes compiler errors where __cpp_concepts is not defined

* Fixes compiler errors where __cpp_concepts is not defined

* Fix test failure for host is wave64 by default

---------

Co-authored-by: Chris Millette <you@example.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Christopher Millette
2025-11-24 10:39:59 -07:00
committed by GitHub
parent 8111572785
commit b9c6cb1452
27 changed files with 3726 additions and 11 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
@@ -9,6 +9,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/utility/ignore.hpp"

#include <string>
#include <string_view>
@@ -60,15 +61,753 @@ enum struct memory_operation_enum : std::uint16_t
add
};
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
namespace core::arch {
/**
* @enum amdgcn_target_id
* @brief Defines constants for AMDGCN architecture target IDs
*/
enum struct amdgcn_target_id
{
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
return 64;
#else
return 32;
#endif
GFX908 = 0x0908, // MI-100...
GFX90A = 0x090A,
GFX942 = 0x0942,
GFX950 = 0x0950,
GFX1030 = 0x1030,
GFX1031 = 0x1031,
GFX1032 = 0x1032,
GFX1034 = 0x1034,
GFX1035 = 0x1035,
GFX1036 = 0x1036,
GFX103_GENERIC = 0x103F,
GFX1100 = 0x1100,
GFX1101 = 0x1101,
GFX1102 = 0x1102,
GFX1103 = 0x1103,
GFX1150 = 0x1150,
GFX1151 = 0x1151,
GFX1152 = 0x1152,
GFX11_GENERIC = 0x11FF,
GFX1200 = 0x1200,
GFX1201 = 0x1201,
GFX12_GENERIC = 0x12FF,
HOST = 0x0000,
};
/**
 * @enum amdgcn_target_family_id
 * @brief Architecture family grouping for AMDGCN targets (e.g. all GFX9 parts share one id).
 */
enum struct amdgcn_target_family_id
{
GFX9 = 0x09, // gfx90x parts (CDNA class)
GFX10_3 = 0x10, // gfx103x parts (RDNA class)
GFX11 = 0x11, // gfx11xx parts (RDNA class)
GFX12 = 0x12, // gfx12xx parts (RDNA class)
HOST = 0x00, // host (CPU) compilation pass
};
/**
 * @enum amdgcn_target_arch_id
 * @brief Architecture class of a target: compute-oriented CDNA vs. graphics-oriented RDNA,
 * or HOST for the host compilation pass.
 */
enum struct amdgcn_target_arch_id
{
CDNA = 0x01, // compute architecture (GFX9 family)
RDNA = 0x02, // graphics architecture (GFX10.3 / GFX11 / GFX12 families)
HOST = 0x00, // host (CPU) compilation pass
};
/**
 * @enum amdgcn_target_wave_size_id
 * @brief Wavefront size of a target; enumerator values equal the lane count.
 */
enum struct amdgcn_target_wave_size_id
{
WAVE32 = 32u,
WAVE64 = 64u,
HOST = 64u, // Host defaults to wave64, matching get_warp_size()'s host fallback of 64.
};
#if 1 //__cplusplus <= 201703L
/**
 * @brief Compile-time description of an AMDGCN target.
 *
 * Bundles the specific target id (e.g. GFX908), its family (e.g. GFX9), its
 * architecture class (CDNA / RDNA) and its wavefront size. All four parameters
 * default to HOST, so amdgcn_target<> denotes the host compilation pass.
 */
template <amdgcn_target_id TargetId = amdgcn_target_id::HOST,
amdgcn_target_family_id FamilyId = amdgcn_target_family_id::HOST,
amdgcn_target_arch_id ArchId = amdgcn_target_arch_id::HOST,
amdgcn_target_wave_size_id WaveSizeId = amdgcn_target_wave_size_id::HOST>
struct amdgcn_target
{
static constexpr amdgcn_target_id TARGET_ID = TargetId; // specific architecture id (e.g. GFX908)
static constexpr amdgcn_target_family_id FAMILY_ID = FamilyId; // architecture family id (e.g. GFX9)
static constexpr amdgcn_target_arch_id ARCH_ID = ArchId; // architecture class id (CDNA / RDNA)
static constexpr amdgcn_target_wave_size_id WAVE_SIZE_ID = WaveSizeId; // wavefront size id (WAVE32 / WAVE64)
};
template <amdgcn_target_id targetId>
static constexpr auto make_amdgcn_gfx9_target()
{
return amdgcn_target<targetId,
amdgcn_target_family_id::GFX9,
amdgcn_target_arch_id::CDNA,
amdgcn_target_wave_size_id::WAVE64>{};
}
template <amdgcn_target_id targetId>
static constexpr auto make_amdgcn_gfx10_3_target()
{
return amdgcn_target<targetId,
amdgcn_target_family_id::GFX10_3,
amdgcn_target_arch_id::RDNA,
amdgcn_target_wave_size_id::WAVE32>{};
}
template <amdgcn_target_id targetId>
static constexpr auto make_amdgcn_gfx11_target()
{
return amdgcn_target<targetId,
amdgcn_target_family_id::GFX11,
amdgcn_target_arch_id::RDNA,
amdgcn_target_wave_size_id::WAVE32>{};
}
template <amdgcn_target_id targetId>
static constexpr auto make_amdgcn_gfx12_target()
{
return amdgcn_target<targetId,
amdgcn_target_family_id::GFX12,
amdgcn_target_arch_id::RDNA,
amdgcn_target_wave_size_id::WAVE32>{};
}
/**
 * @brief True when CompilerTarget's TARGET_ID matches any of the given target ids.
 * @tparam CompilerTarget The compiler target to check
 * @tparam TargetIds Candidate target ids, e.g. amdgcn_target_id::GFX908
 * @note Declared `bool` explicitly (not `auto`) for consistency with the
 * is_target_family_* / is_target_arch_* / is_target_wave_size_* predicates.
 */
template <typename CompilerTarget, amdgcn_target_id... TargetIds>
static constexpr bool is_target_id_any_of()
{
    return is_any_value_of(CompilerTarget::TARGET_ID, TargetIds...);
}
/**
 * @brief True when CompilerTarget's FAMILY_ID matches any of the given family ids.
 * @tparam CompilerTarget The compiler target to check
 * @tparam FamilyIds Candidate family ids, e.g. amdgcn_target_family_id::GFX9
 * @note Declared `bool` explicitly (not `auto`) for consistency with the
 * is_target_family_* / is_target_arch_* / is_target_wave_size_* predicates.
 */
template <typename CompilerTarget, amdgcn_target_family_id... FamilyIds>
static constexpr bool is_target_family_any_of()
{
    return is_any_value_of(CompilerTarget::FAMILY_ID, FamilyIds...);
}
/// @brief True when CompilerTarget belongs to the GFX9 family.
template <typename CompilerTarget>
static constexpr bool is_target_family_gfx9()
{
return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX9;
}
/// @brief True when CompilerTarget belongs to the GFX10.3 family.
template <typename CompilerTarget>
static constexpr bool is_target_family_gfx10_3()
{
return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX10_3;
}
/// @brief True when CompilerTarget belongs to the GFX11 family.
template <typename CompilerTarget>
static constexpr bool is_target_family_gfx11()
{
return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX11;
}
/// @brief True when CompilerTarget belongs to the GFX12 family.
template <typename CompilerTarget>
static constexpr bool is_target_family_gfx12()
{
return CompilerTarget::FAMILY_ID == amdgcn_target_family_id::GFX12;
}
/// @brief True when CompilerTarget is a CDNA-class (compute) architecture.
template <typename CompilerTarget>
static constexpr bool is_target_arch_cdna()
{
return CompilerTarget::ARCH_ID == amdgcn_target_arch_id::CDNA;
}
/// @brief True when CompilerTarget is an RDNA-class (graphics) architecture.
template <typename CompilerTarget>
static constexpr bool is_target_arch_rdna()
{
return CompilerTarget::ARCH_ID == amdgcn_target_arch_id::RDNA;
}
/// @brief True when CompilerTarget executes with 32-lane wavefronts.
template <typename CompilerTarget>
static constexpr bool is_target_wave_size_32()
{
return CompilerTarget::WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE32;
}
/// @brief True when CompilerTarget executes with 64-lane wavefronts.
template <typename CompilerTarget>
static constexpr bool is_target_wave_size_64()
{
return CompilerTarget::WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE64;
}
// Helper macros used by get_compiler_target(): each expands to one branch of an
// `if constexpr` / `else` chain, so exactly one compiler-state branch is taken.
// Maps a GFX9 compiler state to its amdgcn_target (CDNA, wave64).
#define MAP_COMPILER_STATE_TO_GFX9_TARGET(COMPILER_STATE, TARGET_ID) \
if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
{ \
return make_amdgcn_gfx9_target<amdgcn_target_id::TARGET_ID>(); \
} \
else
// Maps a GFX10.3 compiler state to its amdgcn_target (RDNA, wave32).
#define MAP_COMPILER_STATE_TO_GFX10_3_TARGET(COMPILER_STATE, TARGET_ID) \
if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
{ \
return make_amdgcn_gfx10_3_target<amdgcn_target_id::TARGET_ID>(); \
} \
else
// Maps a GFX11 compiler state to its amdgcn_target (RDNA, wave32).
#define MAP_COMPILER_STATE_TO_GFX11_TARGET(COMPILER_STATE, TARGET_ID) \
if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
{ \
return make_amdgcn_gfx11_target<amdgcn_target_id::TARGET_ID>(); \
} \
else
// Maps a GFX12 compiler state to its amdgcn_target (RDNA, wave32).
#define MAP_COMPILER_STATE_TO_GFX12_TARGET(COMPILER_STATE, TARGET_ID) \
if constexpr(amdgcn_compiler_target_state::COMPILER_STATE) \
{ \
return make_amdgcn_gfx12_target<amdgcn_target_id::TARGET_ID>(); \
} \
else
/**
 * @brief Returns the amdgcn_target of the current compiler pass.
 * @note This is where we tie the compiler state to our internal target architecture representation
 * at compile time.
 * @note The MAP_* macros expand to an `if constexpr` / `else` chain, so exactly one
 * branch is selected; the trailing `if constexpr` handles the host pass. If no
 * state matches, the static_asserts below this function fire.
 */
constexpr auto get_compiler_target()
{
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX908, GFX908);
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX90A, GFX90A);
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX942, GFX942);
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX950, GFX950);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1030, GFX1030);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1031, GFX1031);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1032, GFX1032);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1034, GFX1034);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1035, GFX1035);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1036, GFX1036);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX10_3_GENERIC, GFX103_GENERIC);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1100, GFX1100);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1101, GFX1101);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1102, GFX1102);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1103, GFX1103);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1150, GFX1150);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1151, GFX1151);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1152, GFX1152);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX11_GENERIC, GFX11_GENERIC);
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1200, GFX1200);
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1201, GFX1201);
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX12_GENERIC, GFX12_GENERIC);
// Return HOST by default
if constexpr(amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE)
{
return amdgcn_target<>{};
}
}
// Cleanup: the MAP_* helpers are implementation details of get_compiler_target().
#undef MAP_COMPILER_STATE_TO_GFX9_TARGET
#undef MAP_COMPILER_STATE_TO_GFX10_3_TARGET
#undef MAP_COMPILER_STATE_TO_GFX11_TARGET
#undef MAP_COMPILER_STATE_TO_GFX12_TARGET
// Sanity check: device compile must have a valid (non-HOST) target architecture
static_assert(!amdgcn_compiler_target_state::CK_TILE_DEVICE_COMPILE ||
get_compiler_target().TARGET_ID != amdgcn_target_id::HOST,
"Device compile must have a valid target device architecture");
// Sanity check: host compile must have HOST target architecture
static_assert(!amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE ||
get_compiler_target().TARGET_ID == amdgcn_target_id::HOST,
"Host compile must target HOST architecture");
// TODO: c++20 use the make functions and constexpr if to avoid string construction and find at
// runtime
// Expands to one branch of a runtime if/else chain: returns the TARGET_ID when the
// gcnArchName string (`str`, in the enclosing scope) contains NAME_STRING.
#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID(NAME_STRING, TARGET_ID) \
if(str.find(NAME_STRING) != std::string::npos) \
{ \
return amdgcn_target_id::TARGET_ID; \
} \
else
/**
 * @brief Converts a lower-case string to the corresponding amdgcn_target_id value.
 * Returns amdgcn_target_id::HOST if no match is found.
 * Matches if the input contains the architecture substring.
 * Example: "gfx908", "gfx90a", "gfx1100", etc. can be parsed from hip runtime info.
 * @param testStr NUL-terminated architecture name (e.g. hipDeviceProp_t::gcnArchName).
 */
// TODO: c++20 use constexpr if to perform the mapping at compile time
// TODO: c++20 return amdgcn_target instance instead of just the target id
CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target_id(char const* testStr)
{
    // std::string_view (C++17) performs the same substring search as std::string
    // but without a heap allocation/copy of the input.
    auto str = std::string_view(testStr);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx908", GFX908);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx90a", GFX90A);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx942", GFX942);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx950", GFX950);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1030", GFX1030);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1031", GFX1031);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1032", GFX1032);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1034", GFX1034);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1035", GFX1035);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1036", GFX1036);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx10_3_generic", GFX103_GENERIC);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1100", GFX1100);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1101", GFX1101);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1102", GFX1102);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1103", GFX1103);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1150", GFX1150);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1151", GFX1151);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1152", GFX1152);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx11_generic", GFX11_GENERIC);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1200", GFX1200);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1201", GFX1201);
    MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx12_generic", GFX12_GENERIC);
    // Default case: return HOST target if no match is found
    return amdgcn_target_id::HOST;
}
#undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID
// SFINAE enablers: each alias names `void` when the predicate holds and is otherwise
// ill-formed, so it can be used as a default `Enabler` template argument.
/**
 * @brief SFINAE enabler for a compiler target if the target id is in the list of supported target
 * ids
 * @tparam CompilerTarget The compiler target to check
 * @tparam SupportedTargetIds The list of supported target ids, e.g., amdgcn_target_id::GFX908
 */
template <typename CompilerTarget, amdgcn_target_id... SupportedTargetIds>
using enable_if_target_id_t =
std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID, SupportedTargetIds...)>;
/**
 * @brief SFINAE enabler for a compiler target if the family id is in the list of supported family
 * ids
 * @tparam CompilerTarget The compiler target to check
 * @tparam SupportedTargetFamilyIds The list of supported family ids, e.g.,
 * amdgcn_target_family_id::GFX9
 */
template <typename CompilerTarget, amdgcn_target_family_id... SupportedTargetFamilyIds>
using enable_if_target_family_id_t =
std::enable_if_t<is_any_value_of(CompilerTarget::FAMILY_ID, SupportedTargetFamilyIds...)>;
/**
 * @brief SFINAE enabler for a compiler target if the arch id is in the list of supported arch ids
 * @tparam CompilerTarget The compiler target to check
 * @tparam SupportedTargetArchIds The list of supported arch ids, e.g., amdgcn_target_arch_id::CDNA
 */
template <typename CompilerTarget, amdgcn_target_arch_id... SupportedTargetArchIds>
using enable_if_target_arch_id_t =
std::enable_if_t<is_any_value_of(CompilerTarget::ARCH_ID, SupportedTargetArchIds...)>;
/**
 * @brief SFINAE enabler for a compiler target if the wave size id is in the list of supported wave
 * size ids
 * @tparam CompilerTarget The compiler target to check
 * @tparam SupportedTargetWaveSizeIds The list of supported wave size ids, e.g.,
 * amdgcn_target_wave_size_id::WAVE64
 */
template <typename CompilerTarget, amdgcn_target_wave_size_id... SupportedTargetWaveSizeIds>
using enable_if_target_wave_size_id_t =
std::enable_if_t<is_any_value_of(CompilerTarget::WAVE_SIZE_ID, SupportedTargetWaveSizeIds...)>;
/// Specialized enablers for common families, architectures, and wave sizes ///
/// (thin shorthands over the generic enable_if_target_*_id_t aliases above) ///
/**
 * @brief SFINAE enabler for GFX9 family targets
 * @tparam CompilerTarget The compiler target to check
 */
template <typename CompilerTarget>
using enable_if_target_family_gfx9_t =
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX9>;
/**
 * @brief SFINAE enabler for GFX10.3 family targets
 * @tparam CompilerTarget The compiler target to check
 */
template <typename CompilerTarget>
using enable_if_target_family_gfx10_3_t =
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX10_3>;
/**
 * @brief SFINAE enabler for GFX11 family targets
 * @tparam CompilerTarget The compiler target to check
 */
template <typename CompilerTarget>
using enable_if_target_family_gfx11_t =
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX11>;
/**
 * @brief SFINAE enabler for GFX12 family targets
 * @tparam CompilerTarget The compiler target to check
 */
template <typename CompilerTarget>
using enable_if_target_family_gfx12_t =
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX12>;
/**
 * @brief SFINAE enabler for CDNA architecture targets
 * @tparam CompilerTarget The compiler target to check
 */
template <typename CompilerTarget>
using enable_if_target_arch_cdna_t =
enable_if_target_arch_id_t<CompilerTarget, amdgcn_target_arch_id::CDNA>;
/**
 * @brief SFINAE enabler for RDNA architecture targets
 * @tparam CompilerTarget The compiler target to check
 */
template <typename CompilerTarget>
using enable_if_target_arch_rdna_t =
enable_if_target_arch_id_t<CompilerTarget, amdgcn_target_arch_id::RDNA>;
/**
 * @brief SFINAE enabler for WAVE32 size targets
 * @tparam CompilerTarget The compiler target to check
 */
template <typename CompilerTarget>
using enable_if_target_wave32_t =
enable_if_target_wave_size_id_t<CompilerTarget, amdgcn_target_wave_size_id::WAVE32>;
/**
 * @brief SFINAE enabler for WAVE64 size targets
 * @tparam CompilerTarget The compiler target to check
 */
template <typename CompilerTarget>
using enable_if_target_wave64_t =
enable_if_target_wave_size_id_t<CompilerTarget, amdgcn_target_wave_size_id::WAVE64>;
#elif __cplusplus >= 202002L
/**
 * @brief C++20 variant: value-based description of an AMDGCN target, usable as a
 * non-type template parameter. All ids default to HOST (the host compilation pass);
 * use the make_amdgcn_*_target() factories below for device targets.
 */
struct amdgcn_target
{
// Target architecture identifiers
// These are set to HOST (0) by default
// TARGET_ID is the specific architecture id (e.g., GFX908)
// FAMILY_ID is the architecture family id (e.g., GFX9)
// ARCH_ID is the architecture class id (e.g., CDNA, RDNA)
// WAVE_SIZE_ID is the wavefront size id (e.g., WAVE32, WAVE64)
const amdgcn_target_id TARGET_ID = amdgcn_target_id::HOST;
const amdgcn_target_family_id FAMILY_ID = amdgcn_target_family_id::HOST;
const amdgcn_target_arch_id ARCH_ID = amdgcn_target_arch_id::HOST;
const amdgcn_target_wave_size_id WAVE_SIZE_ID = amdgcn_target_wave_size_id::HOST;
};
/**
 * @brief Builds the amdgcn_target describing a GFX10.3-family (RDNA, wave32) part.
 * @param targetId The specific target id, e.g. amdgcn_target_id::GFX1030.
 */
static constexpr auto make_amdgcn_gfx10_3_target(amdgcn_target_id targetId)
{
    amdgcn_target target{.TARGET_ID = targetId,
                         .FAMILY_ID = amdgcn_target_family_id::GFX10_3,
                         .ARCH_ID = amdgcn_target_arch_id::RDNA,
                         .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32};
    return target;
}
/**
 * @brief Builds the amdgcn_target describing a GFX9-family (CDNA, wave64) part.
 * @param targetId The specific target id, e.g. amdgcn_target_id::GFX908.
 */
static constexpr auto make_amdgcn_gfx9_target(amdgcn_target_id targetId)
{
    amdgcn_target target{.TARGET_ID = targetId,
                         .FAMILY_ID = amdgcn_target_family_id::GFX9,
                         .ARCH_ID = amdgcn_target_arch_id::CDNA,
                         .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE64};
    return target;
}
/**
 * @brief Builds the amdgcn_target describing a GFX11-family (RDNA, wave32) part.
 * @param targetId The specific target id, e.g. amdgcn_target_id::GFX1100.
 */
static constexpr auto make_amdgcn_gfx11_target(amdgcn_target_id targetId)
{
    amdgcn_target target{.TARGET_ID = targetId,
                         .FAMILY_ID = amdgcn_target_family_id::GFX11,
                         .ARCH_ID = amdgcn_target_arch_id::RDNA,
                         .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32};
    return target;
}
/**
 * @brief Builds the amdgcn_target describing a GFX12-family (RDNA, wave32) part.
 * @param targetId The specific target id, e.g. amdgcn_target_id::GFX1200.
 */
static constexpr auto make_amdgcn_gfx12_target(amdgcn_target_id targetId)
{
    amdgcn_target target{.TARGET_ID = targetId,
                         .FAMILY_ID = amdgcn_target_family_id::GFX12,
                         .ARCH_ID = amdgcn_target_arch_id::RDNA,
                         .WAVE_SIZE_ID = amdgcn_target_wave_size_id::WAVE32};
    return target;
}
/// @brief True when the target belongs to the GFX9 family.
static constexpr bool is_target_family_gfx9(amdgcn_target target)
{
return target.FAMILY_ID == amdgcn_target_family_id::GFX9;
}
/// @brief True when the target belongs to the GFX10.3 family.
static constexpr bool is_target_family_gfx10_3(amdgcn_target target)
{
return target.FAMILY_ID == amdgcn_target_family_id::GFX10_3;
}
/// @brief True when the target belongs to the GFX11 family.
static constexpr bool is_target_family_gfx11(amdgcn_target target)
{
return target.FAMILY_ID == amdgcn_target_family_id::GFX11;
}
/// @brief True when the target belongs to the GFX12 family.
static constexpr bool is_target_family_gfx12(amdgcn_target target)
{
return target.FAMILY_ID == amdgcn_target_family_id::GFX12;
}
/// @brief True when the target is a CDNA-class (compute) architecture.
static constexpr bool is_target_arch_cdna(amdgcn_target target)
{
return target.ARCH_ID == amdgcn_target_arch_id::CDNA;
}
/// @brief True when the target is an RDNA-class (graphics) architecture.
static constexpr bool is_target_arch_rdna(amdgcn_target target)
{
return target.ARCH_ID == amdgcn_target_arch_id::RDNA;
}
/// @brief True when the target executes with 32-lane wavefronts.
static constexpr bool is_target_wave_size_32(amdgcn_target target)
{
return target.WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE32;
}
/// @brief True when the target executes with 64-lane wavefronts.
static constexpr bool is_target_wave_size_64(amdgcn_target target)
{
return target.WAVE_SIZE_ID == amdgcn_target_wave_size_id::WAVE64;
}
// Helper macros used by get_compiler_target(): each expands to an `if constexpr`
// branch; for a non-matching compiler state the branch is discarded, so at most one
// `return` survives per instantiation.
// BUGFIX: the GFX10_3 macro previously called make_amdgcn_gfx9_target, which would
// mislabel every GFX10.3 part as GFX9 family / CDNA / wave64.
#define MAP_COMPILER_STATE_TO_GFX10_3_TARGET(COMPILER_STATE, TARGET_ID) \
    if constexpr(amdgcn_compiler_target_state::COMPILER_STATE)          \
    {                                                                   \
        return make_amdgcn_gfx10_3_target(amdgcn_target_id::TARGET_ID); \
    }
#define MAP_COMPILER_STATE_TO_GFX9_TARGET(COMPILER_STATE, TARGET_ID) \
    if constexpr(amdgcn_compiler_target_state::COMPILER_STATE)       \
    {                                                                \
        return make_amdgcn_gfx9_target(amdgcn_target_id::TARGET_ID); \
    }
#define MAP_COMPILER_STATE_TO_GFX11_TARGET(COMPILER_STATE, TARGET_ID) \
    if constexpr(amdgcn_compiler_target_state::COMPILER_STATE)        \
    {                                                                 \
        return make_amdgcn_gfx11_target(amdgcn_target_id::TARGET_ID); \
    }
#define MAP_COMPILER_STATE_TO_GFX12_TARGET(COMPILER_STATE, TARGET_ID) \
    if constexpr(amdgcn_compiler_target_state::COMPILER_STATE)        \
    {                                                                 \
        return make_amdgcn_gfx12_target(amdgcn_target_id::TARGET_ID); \
    }
/*! @brief Returns the amdgcn_target of the current compiler pass.
 * @note This is where we tie the compiler state to our internal target architecture representation
 * at compile time.
 * @note Each MAP_* macro is an independent `if constexpr`; non-matching branches are
 * discarded, and the trailing return handles the host pass when nothing matched.
 */
CK_TILE_HOST_DEVICE constexpr auto get_compiler_target()
{
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX908, GFX908);
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX90A, GFX90A);
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX942, GFX942);
MAP_COMPILER_STATE_TO_GFX9_TARGET(CK_TILE_ARCH_GFX950, GFX950);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1030, GFX1030);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1031, GFX1031);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1032, GFX1032);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1034, GFX1034);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1035, GFX1035);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1036, GFX1036);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX10_3_GENERIC, GFX103_GENERIC);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1100, GFX1100);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1101, GFX1101);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1102, GFX1102);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1103, GFX1103);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1150, GFX1150);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1151, GFX1151);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1152, GFX1152);
MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX11_GENERIC, GFX11_GENERIC);
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1200, GFX1200);
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1201, GFX1201);
MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX12_GENERIC, GFX12_GENERIC);
// Default to HOST
return amdgcn_target{};
}
// Cleanup: the MAP_* helpers are implementation details of get_compiler_target().
#undef MAP_COMPILER_STATE_TO_GFX9_TARGET
#undef MAP_COMPILER_STATE_TO_GFX10_3_TARGET
#undef MAP_COMPILER_STATE_TO_GFX11_TARGET
#undef MAP_COMPILER_STATE_TO_GFX12_TARGET
// Sanity check: device compile must have a valid (non-HOST) target architecture
static_assert(!amdgcn_compiler_target_state::CK_TILE_DEVICE_COMPILE ||
get_compiler_target().TARGET_ID != amdgcn_target_id::HOST,
"Device compile must have a valid target device architecture");
// Sanity check: host compile must have HOST target architecture
static_assert(!amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE ||
get_compiler_target().TARGET_ID == amdgcn_target_id::HOST,
"Host compile must target HOST architecture");
// Helper macros: map a gcnArchName substring match to an amdgcn_target, expanding to
// one branch of a runtime if/else chain over `str` in the enclosing scope.
// BUGFIX: these previously used `if constexpr`, which is ill-formed here — the
// condition `str.find(...)` involves a runtime std::string and is not a constant
// expression; a plain `if` is required.
#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET(NAME_STRING, TARGET_ID) \
    if(str.find(NAME_STRING) != std::string::npos)                                      \
    {                                                                                   \
        return make_amdgcn_gfx9_target(amdgcn_target_id::TARGET_ID);                    \
    }                                                                                   \
    else
#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET(NAME_STRING, TARGET_ID) \
    if(str.find(NAME_STRING) != std::string::npos)                                         \
    {                                                                                      \
        return make_amdgcn_gfx10_3_target(amdgcn_target_id::TARGET_ID);                    \
    }                                                                                      \
    else
#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET(NAME_STRING, TARGET_ID) \
    if(str.find(NAME_STRING) != std::string::npos)                                       \
    {                                                                                    \
        return make_amdgcn_gfx11_target(amdgcn_target_id::TARGET_ID);                    \
    }                                                                                    \
    else
#define MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET(NAME_STRING, TARGET_ID) \
    if(str.find(NAME_STRING) != std::string::npos)                                       \
    {                                                                                    \
        return make_amdgcn_gfx12_target(amdgcn_target_id::TARGET_ID);                    \
    }                                                                                    \
    else
/**
 * @brief Converts a lower-case string to the corresponding amdgcn_target value.
 * Returns a default (HOST) amdgcn_target if no match is found.
 * Matches if the input contains the architecture substring.
 * Example: "gfx908", "gfx90a", "gfx1100", etc. can be parsed from hip runtime info.
 * @param testStr NUL-terminated architecture name (e.g. hipDeviceProp_t::gcnArchName).
 * @note The MAP_* macros expand into a runtime if/else chain over `str`.
 */
CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target(char const* testStr)
{
auto str = std::string(testStr);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx908", GFX908);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx90a", GFX90A);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx942", GFX942);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET("gfx950", GFX950);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1030", GFX1030);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1031", GFX1031);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1032", GFX1032);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1034", GFX1034);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1035", GFX1035);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1036", GFX1036);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx10_3_generic", GFX103_GENERIC);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1100", GFX1100);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1101", GFX1101);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1102", GFX1102);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1103", GFX1103);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1150", GFX1150);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1151", GFX1151);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1152", GFX1152);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx11_generic", GFX11_GENERIC);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx1200", GFX1200);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx1201", GFX1201);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx12_generic", GFX12_GENERIC);
// Default case
return amdgcn_target{};
}
#undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX9_TARGET
#undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET
#undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET
#undef MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET
// C++20 variants of the SFINAE enablers: CompilerTarget is an amdgcn_target value
// (non-type template parameter) rather than a type.
/**
 * @brief SFINAE enabler for a compiler target if the target id is in the list of supported target
 * ids
 * @tparam CompilerTarget The compiler target to check
 * @tparam SupportedTargetIds The list of supported target ids, e.g., amdgcn_target_id::GFX908
 */
template <amdgcn_target CompilerTarget, amdgcn_target_id... SupportedTargetIds>
using enable_if_target_id_t =
std::enable_if_t<is_any_value_of(CompilerTarget.TARGET_ID, SupportedTargetIds...)>;
/**
 * @brief SFINAE enabler for a compiler target if the family id is in the list of supported family
 * ids
 * @tparam CompilerTarget The compiler target to check
 * @tparam SupportedTargetFamilyIds The list of supported family ids, e.g.,
 * amdgcn_target_family_id::GFX9
 */
template <amdgcn_target CompilerTarget, amdgcn_target_family_id... SupportedTargetFamilyIds>
using enable_if_target_family_id_t =
std::enable_if_t<is_any_value_of(CompilerTarget.FAMILY_ID, SupportedTargetFamilyIds...)>;
/**
 * @brief SFINAE enabler for a compiler target if the arch id is in the list of supported arch ids
 * @tparam CompilerTarget The compiler target to check
 * @tparam SupportedTargetArchIds The list of supported arch ids, e.g., amdgcn_target_arch_id::CDNA
 */
template <amdgcn_target CompilerTarget, amdgcn_target_arch_id... SupportedTargetArchIds>
using enable_if_target_arch_id_t =
std::enable_if_t<is_any_value_of(CompilerTarget.ARCH_ID, SupportedTargetArchIds...)>;
/**
 * @brief SFINAE enabler for a compiler target if the wave size id is in the list of supported wave
 * size ids
 * @tparam CompilerTarget The compiler target to check
 * @tparam SupportedTargetWaveSizeIds The list of supported wave size ids, e.g.,
 * amdgcn_target_wave_size_id::WAVE64
 */
template <amdgcn_target CompilerTarget, amdgcn_target_wave_size_id... SupportedTargetWaveSizeIds>
using enable_if_target_wave_size_id_t =
std::enable_if_t<is_any_value_of(CompilerTarget.WAVE_SIZE_ID, SupportedTargetWaveSizeIds...)>;
/// Specialized enablers for common families, architectures, and wave sizes ///
/// (thin shorthands over the generic enable_if_target_*_id_t aliases above) ///
/**
 * @brief SFINAE enabler for GFX9 family targets
 * @tparam CompilerTarget The compiler target to check
 */
template <amdgcn_target CompilerTarget>
using enable_if_target_family_gfx9_t =
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX9>;
/**
 * @brief SFINAE enabler for GFX10.3 family targets
 * @tparam CompilerTarget The compiler target to check
 */
template <amdgcn_target CompilerTarget>
using enable_if_target_family_gfx10_3_t =
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX10_3>;
/**
 * @brief SFINAE enabler for GFX11 family targets
 * @tparam CompilerTarget The compiler target to check
 */
template <amdgcn_target CompilerTarget>
using enable_if_target_family_gfx11_t =
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX11>;
/**
 * @brief SFINAE enabler for GFX12 family targets
 * @tparam CompilerTarget The compiler target to check
 */
template <amdgcn_target CompilerTarget>
using enable_if_target_family_gfx12_t =
enable_if_target_family_id_t<CompilerTarget, amdgcn_target_family_id::GFX12>;
/**
 * @brief SFINAE enabler for CDNA architecture targets
 * @tparam CompilerTarget The compiler target to check
 */
template <amdgcn_target CompilerTarget>
using enable_if_target_arch_cdna_t =
enable_if_target_arch_id_t<CompilerTarget, amdgcn_target_arch_id::CDNA>;
/**
 * @brief SFINAE enabler for RDNA architecture targets
 * @tparam CompilerTarget The compiler target to check
 */
template <amdgcn_target CompilerTarget>
using enable_if_target_arch_rdna_t =
enable_if_target_arch_id_t<CompilerTarget, amdgcn_target_arch_id::RDNA>;
/**
 * @brief SFINAE enabler for WAVE32 size targets
 * @tparam CompilerTarget The compiler target to check
 */
template <amdgcn_target CompilerTarget>
using enable_if_target_wave32_t =
enable_if_target_wave_size_id_t<CompilerTarget, amdgcn_target_wave_size_id::WAVE32>;
/**
 * @brief SFINAE enabler for WAVE64 size targets
 * @tparam CompilerTarget The compiler target to check
 */
template <amdgcn_target CompilerTarget>
using enable_if_target_wave64_t =
enable_if_target_wave_size_id_t<CompilerTarget, amdgcn_target_wave_size_id::WAVE64>;
#endif // __cplusplus <= 201703L
} // namespace core::arch
CK_TILE_HOST bool is_wave32()
{
hipDeviceProp_t props{};
@@ -86,6 +825,13 @@ CK_TILE_HOST bool is_wave32()
return props.major > 9;
}
/*! @brief Returns the wavefront (warp) size of the current compiler pass,
 * derived from the compile-time target (64 for wave64 targets and host, 32 for wave32).
 */
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
{
return static_cast<index_t>(core::arch::get_compiler_target().WAVE_SIZE_ID);
}
/// @brief Number of thread blocks in the launch grid (x dimension only).
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }
/// @brief Number of threads per block (x dimension only).
CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; }

View File

@@ -0,0 +1,118 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/utility/ignore.hpp"
namespace ck_tile::core::arch::mma {
/**
 * @struct Unsupported
 * @brief Meta-tag to indicate unsupported amdgcn_mma instance.
 * Deliberately declared but never defined: using it as an OpType is detectable
 * at compile time, and any attempt to instantiate it fails to compile.
 */
struct Unsupported;
#if defined(__cpp_concepts) && __cpp_concepts >= 201907L
/**
 * @concept MmaOpI
 * @brief Expresses the meta-data interface required for each MmaOp policy.
 * @note Guarded on __cpp_concepts so the header still compiles in C++17 builds.
 */
template <typename MmaOp>
concept MmaOpI = requires(MmaOp op) {
// Requires an op context (tag type; Unsupported marks a missing implementation)
typename MmaOp::OpType;
// Captures types for inputs / outputs to mma function
typename MmaOp::AVecType;
typename MmaOp::BVecType;
typename MmaOp::CVecType;
// Captures CK-specific layout properties
{ MmaOp::kAMBlock } -> std::convertible_to<unsigned int>;
{ MmaOp::kBNBlock } -> std::convertible_to<unsigned int>;
{ MmaOp::kAMLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kBNLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kABKLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kABKPerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCMLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCNLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCM0PerLane } -> std::convertible_to<unsigned int>;
{ MmaOp::kCM1PerLane } -> std::convertible_to<unsigned int>;
// Static exec function: C' = mma(A, B, C)
{
MmaOp::exec(
typename MmaOp::AVecType{}, typename MmaOp::BVecType{}, typename MmaOp::CVecType{})
} -> std::convertible_to<typename MmaOp::CVecType>;
};
#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L
/**
 * @class amdgcn_mma
 * @brief This is the default MmaOp policy.
 * Instances of this class are to be used as MmaOp policies.
 * Light builtin wrapper for mfma / wmma instructions. This class's job is to
 * provide a uniform interface to invoke the appropriate instruction
 * based on the template parameters provided. This interface is to bridge
 * the gap between the ck_tile API types and the native __builtin types.
 * @tparam ADataType Datatype of input A
 * @tparam BDataType Datatype of input B
 * @tparam CDataType Datatype of accumulator
 * @tparam BlockM M-dimension of mma block
 * @tparam BlockN N-dimension of mma block
 * @tparam BlockK K-dimension of mma block
 * @tparam CtrlFlags Control flags for mma operation
 * @tparam CompilerTarget The current compiler target
 * @tparam Enabler SFINAE enabler
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t BlockM,
          uint32_t BlockN,
          uint32_t BlockK,
          typename CtrlFlags,
          typename CompilerTarget,
          typename Enabler = void>
struct amdgcn_mma
{
    // The base instance is unsupported because there is no __builtin to wrap.
    using OpType = Unsupported;
    // Interface types for A, B, C vectors types
    using AVecType = ext_vector_t<ADataType, 1>;
    using BVecType = ext_vector_t<BDataType, 1>;
    using CVecType = ext_vector_t<CDataType, 1>;
    // Layout constants - default to 0
    static constexpr index_t kAMBlock    = 0;
    static constexpr index_t kBNBlock    = 0;
    static constexpr index_t kAMLane     = 0;
    static constexpr index_t kBNLane     = 0;
    static constexpr index_t kABKLane    = 0;
    static constexpr index_t kABKPerLane = 0;
    static constexpr index_t kCMLane     = 0;
    static constexpr index_t kCNLane     = 0;
    static constexpr index_t kCM0PerLane = 0;
    static constexpr index_t kCM1PerLane = 0;
    // This is a default pass-through implementation that doesn't do anything practical.
    // NOTE(review): returns a reference to the regsC parameter; callers must consume the
    // result within the same expression (e.g. assign it) and not retain the reference.
    CK_TILE_DEVICE static CVecType const&
    exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
    {
        ignore(regsA, regsB);
        return regsC; // No-op, just return C
    }
};
} // namespace ck_tile::core::arch::mma
// Include the implementations
#include "wmma/wmma.hpp"
#include "mfma/mfma.hpp"

View File

@@ -0,0 +1,10 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
// Include the architecture-specific MFMA implementations and traits
#include "mfma_gfx9.hpp"
#include "mfma_traits.hpp"
#include "mfma_selector.hpp"
#include "mfma_transforms.hpp"

View File

@@ -0,0 +1,162 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "mfma_traits.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
namespace ck_tile::core::arch::mma {
// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed.
// This is because some built-ins are only available on certain target ids.
// We can also do things such add some padding specializations for when we need to use
// smaller values of K that aren't directly supported by the built-ins.
// For flexibility, it is recommended that for each backend wrapper it supports at least
// one packed register for each input to be able to process smaller K values by padding.
/**
 * @struct DefaultMfmaCtrlFlags
 * @brief Default MFMA flags, no broadcasting or rotation of inputs
 */
struct DefaultMfmaCtrlFlags
{
    static constexpr uint32_t Cbsz = 0; // CBSZ flag, default 0
    static constexpr uint32_t Abid = 0; // ABID flag, default 0
    static constexpr uint32_t Blgp = 0; // BLGP flag, default 0
};
#if defined(__cpp_concepts) && __cpp_concepts >= 201907L
/**
 * @concept CtrlFlagsGfx9I
 * @brief Expresses the interface of required members for each CtrlFlags type on Gfx9
 */
template <typename CtrlFlags>
concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) {
    // Flag members for Gfx9 MFMA instructions (passed as int to the builtins)
    { CtrlFlags::Cbsz } -> std::convertible_to<int>;
    { CtrlFlags::Abid } -> std::convertible_to<int>;
    { CtrlFlags::Blgp } -> std::convertible_to<int>;
};
#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L
/**
 * @struct amdgcn_mma
 * @brief Specialization of amdgcn_mma for MFMA on GFX9 targets
 *
 * This specialization implements the MFMA instruction for fp16_t A and B
 * matrices, and fp32_t accumulator matrix, with 16x16x16 block sizes.
 * Wraps __builtin_amdgcn_mfma_f32_16x16x16f16.
 *
 * @tparam CtrlFlags Control flags for the MFMA operation
 * @tparam CompilerTarget Current compiler target
 */
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t,
                  fp16_t,
                  fp32_t,
                  16u,
                  16u,
                  16u,
                  CtrlFlags,
                  CompilerTarget,
                  enable_if_target_family_gfx9_t<CompilerTarget>>
{
    // Mfma operation type
    using OpType = MfmaOp;
    // Register types (4 packed values per lane for each operand)
    using AVecType = ext_vector_t<fp16_t, 4>;
    using BVecType = ext_vector_t<fp16_t, 4>;
    using CVecType = ext_vector_t<fp32_t, 4>;
    // Layout constants
    static constexpr index_t kAMBlock    = 1;
    static constexpr index_t kBNBlock    = 1;
    static constexpr index_t kAMLane     = 16;
    static constexpr index_t kBNLane     = 16;
    static constexpr index_t kABKLane    = 4;
    static constexpr index_t kABKPerLane = 4;
    static constexpr index_t kCMLane     = 4;
    static constexpr index_t kCNLane     = 16;
    static constexpr index_t kCM0PerLane = 1;
    static constexpr index_t kCM1PerLane = 4;
    // Invokes the builtin with the compile-time control flags from CtrlFlags.
    CK_TILE_DEVICE static auto
    exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
    {
        return {__builtin_amdgcn_mfma_f32_16x16x16f16(aVec,
                                                      bVec,
                                                      cVec,
                                                      static_cast<int>(CtrlFlags::Cbsz),
                                                      static_cast<int>(CtrlFlags::Abid),
                                                      static_cast<int>(CtrlFlags::Blgp))};
    }
};
/**
 * @struct amdgcn_mma
 * @brief Specialization of amdgcn_mma for MFMA on GFX950 targets
 *
 * This specialization implements the MFMA instruction for fp16_t A and B
 * matrices, and fp32_t accumulator matrix, with 16x16x32 block sizes.
 * Wraps __builtin_amdgcn_mfma_f32_16x16x32_f16 (only available on gfx950).
 *
 * @tparam CtrlFlags Control flags for the MFMA operation
 * @tparam CompilerTarget Current compiler target
 */
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t,
                  fp16_t,
                  fp32_t,
                  16u,
                  16u,
                  32u,
                  CtrlFlags,
                  CompilerTarget,
                  enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
{
    using OpType = MfmaOp;
    // Packed register types (8 packed values per lane for A and B)
    using AVecType = ext_vector_t<fp16_t, 8>;
    using BVecType = ext_vector_t<fp16_t, 8>;
    using CVecType = ext_vector_t<fp32_t, 4>;
    // Layout constants
    static constexpr index_t kAMBlock    = 1;
    static constexpr index_t kBNBlock    = 1;
    static constexpr index_t kAMLane     = 16;
    static constexpr index_t kBNLane     = 16;
    static constexpr index_t kABKLane    = 8;
    static constexpr index_t kABKPerLane = 8;
    static constexpr index_t kCMLane     = 4;
    static constexpr index_t kCNLane     = 16;
    static constexpr index_t kCM0PerLane = 1;
    static constexpr index_t kCM1PerLane = 4;
    // Invokes the builtin with the compile-time control flags from CtrlFlags.
    CK_TILE_DEVICE static auto
    exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
    {
        return {__builtin_amdgcn_mfma_f32_16x16x32_f16(aVec,
                                                       bVec,
                                                       cVec,
                                                       static_cast<int>(CtrlFlags::Cbsz),
                                                       static_cast<int>(CtrlFlags::Abid),
                                                       static_cast<int>(CtrlFlags::Blgp))};
    }
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,189 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "mfma_traits.hpp"
#include "mfma_gfx9.hpp"
namespace ck_tile::core::arch::mma {
/**
 * @class MfmaDefaultSelector
 * @brief Implements a default MFMA selector strategy for gfx9 target architectures.
 * This implements the K dimension search strategy to find the largest supported MFMA
 * instruction for the given M/N block sizes and datatypes.
 * If no supported instruction is found, falls back to an unsupported pass-through
 implementation.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 * @tparam BlockM Block M dimension size
 * @tparam BlockN Block N dimension size
 * @tparam BlockKTest Current Block K dimension size to test
 * @tparam CompilerTarget The compiler target
 * @note Here we assume that BlockKTest is always a power-of-two integer.
 * The search strategy starts from a maximum BlockKTest size down to 1u by halving
 * each time.
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t BlockM,
          uint32_t BlockN,
          uint32_t BlockKTest,
          typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest))
struct MfmaDefaultSelector
{
    private:
    // Define our candidate MFMA implementation for the current parameters
    using CandidateOp =
        amdgcn_mma<ADataType,
                   BDataType,
                   CDataType,
                   BlockM,
                   BlockN,
                   BlockKTest,
                   DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
                   CompilerTarget>;
    using CandidateTraits = MmaOpTraits<CandidateOp>;

    public:
    // If the candidate is supported (e.g., a backend implementation exists), then select it.
    // Otherwise, test the next smaller BlockK by halving. If no supported implementation is
    // found, the recursion terminates at the BlockKTest == 1u base-case specialization, which
    // falls back to the unsupported pass-through implementation.
    using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
                                          CandidateOp,
                                          typename MfmaDefaultSelector<ADataType,
                                                                       BDataType,
                                                                       CDataType,
                                                                       BlockM,
                                                                       BlockN,
                                                                       BlockKTest / 2u,
                                                                       CompilerTarget>::SelectedOp>;
};
/**
 * @struct MfmaDefaultSelector
 * @brief Implements the base case for the default MFMA selector when no supported instruction is
 * found.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 * @tparam BlockM Block M dimension size
 * @tparam BlockN Block N dimension size
 * @tparam CompilerTarget The compiler target
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t BlockM,
          uint32_t BlockN,
          typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
struct MfmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CompilerTarget>
{
    // Default unsupported pass-through if no instruction is found.
    // NOTE: resolves to a real K=1 instruction if a backend provides one; otherwise this is
    // the unsupported base amdgcn_mma.
    using SelectedOp =
        amdgcn_mma<ADataType,
                   BDataType,
                   CDataType,
                   BlockM,
                   BlockN,
                   1u,
                   DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
                   CompilerTarget>;
};
/**
 * @struct MmaDefaultSelector
 * @brief Implements the gfx9 default MMA selector strategy for wave-wise MMA decomposition.
 * This implements the M/N block size search strategy to find the largest supported MFMA
 * instruction for the given datatypes.
 * If no supported instruction is found, falls back to an unsupported pass-through implementation.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 * @tparam FragM Size of the M dimension of the fragment to decompose
 * @tparam FragN Size of the N dimension of the fragment to decompose
 * @tparam FragK Size of the K dimension of the fragment to decompose
 * @tparam CompilerTarget The compiler target
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t FragM,
          uint32_t FragN,
          uint32_t FragK,
          typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
struct MmaDefaultSelector<ADataType,
                          BDataType,
                          CDataType,
                          FragM,
                          FragN,
                          FragK,
                          CompilerTarget,
                          enable_if_target_family_gfx9_t<CompilerTarget>>
{
    private:
    // Provide the default depth-K search strategy for each class of common MFMA shapes.
    // Start searching from the largest K dimension MFMA shape down to the smallest.
    using CandidateOp4x4 =
        typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 4u, 4u, 4u, CompilerTarget>::
            SelectedOp;
    using CandidateOp16x16 = typename MfmaDefaultSelector<ADataType,
                                                          BDataType,
                                                          CDataType,
                                                          16u,
                                                          16u,
                                                          128u,
                                                          CompilerTarget>::SelectedOp;
    using CandidateOp32x32 = typename MfmaDefaultSelector<ADataType,
                                                          BDataType,
                                                          CDataType,
                                                          32u,
                                                          32u,
                                                          64u,
                                                          CompilerTarget>::SelectedOp;
    // Default operation triggers pass-through
    using DefaultOp =
        typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
            SelectedOp;
    // Traits for each candidate
    using CandidateTraits4x4   = MmaOpTraits<CandidateOp4x4>;
    using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
    using CandidateTraits32x32 = MmaOpTraits<CandidateOp32x32>;
    // Check if each candidate is supported for the given fragment sizes
    // For this case, we require the fragment sizes to be multiples of the MFMA shape
    static constexpr bool IsSupported4x4 =
        CandidateTraits4x4::IsSupported && (FragM % CandidateTraits4x4::BlockM == 0u) &&
        (FragN % CandidateTraits4x4::BlockN == 0u) && (FragK % CandidateTraits4x4::BlockK == 0u);
    static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
                                             (FragM % CandidateTraits16x16::BlockM == 0u) &&
                                             (FragN % CandidateTraits16x16::BlockN == 0u) &&
                                             (FragK % CandidateTraits16x16::BlockK == 0u);
    static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
                                             (FragM % CandidateTraits32x32::BlockM == 0u) &&
                                             (FragN % CandidateTraits32x32::BlockN == 0u) &&
                                             (FragK % CandidateTraits32x32::BlockK == 0u);

    public:
    // Select the largest supported MFMA operation for the given fragment shape,
    // preferring 32x32 over 16x16 over 4x4, then falling back to the pass-through.
    using SelectedOp = std::conditional_t<
        IsSupported32x32,
        CandidateOp32x32,
        std::conditional_t<IsSupported16x16,
                           CandidateOp16x16,
                           std::conditional_t<IsSupported4x4, CandidateOp4x4, DefaultOp>>>;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,44 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile::core::arch::mma {
/**
 * @struct MfmaOp
 * @brief Meta-tag for the MFMA operation. This will be used in the MmaOp policies to
 * identify the operation as an MFMA operation.
 */
struct MfmaOp;
/**
 * @class is_mma_op_mfma
 * @brief Trait to check if MmaOp is an MFMA operation.
 * Primary template: false for any MmaOp whose OpType is absent or is not MfmaOp.
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
template <typename MmaOp, typename = void>
struct is_mma_op_mfma : std::false_type
{
};
/**
 * @struct is_mma_op_mfma
 * @brief MmaOp specialization for MFMA operations, confirming the OpType matches MfmaOp
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
template <typename MmaOp>
// TODO: c++20 requires
struct is_mma_op_mfma<MmaOp, std::enable_if_t<std::is_same_v<typename MmaOp::OpType, MfmaOp>>>
    : std::true_type
{
};
/**
 * @brief Convenience evaluator for is_mma_op_mfma trait
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
template <typename MmaOp>
static constexpr bool is_mma_op_mfma_v = is_mma_op_mfma<MmaOp>::value;
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,38 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
namespace ck_tile::core::arch::mma {
/**
 * @struct MmaDefaultTransformsGfx9
 * @brief Implements the default MMA transforms for gfx9 targets.
 * All four transforms are pass-through (no layout or datatype conversion).
 */
struct MmaDefaultTransformsGfx9
{
    using ATransform = PassThroughTransform;
    using BTransform = PassThroughTransform;
    using CTransform = PassThroughTransform;
    using DTransform = PassThroughTransform;
};
/**
 * @struct MmaTransformsDefaultSelector
 * @brief Implements the default MMA transforms selection for gfx9 targets
 * @tparam MmaOp Mma operation
 * @tparam CompilerTarget The compiler target
 */
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires
template <typename MmaOp, typename CompilerTarget>
struct MmaTransformsDefaultSelector<MmaOp,
                                    CompilerTarget,
                                    enable_if_target_family_gfx9_t<CompilerTarget>>
{
    using SelectedTransforms = MmaDefaultTransformsGfx9;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,234 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "amdgcn_mma.hpp"
#include "mma_selector.hpp"
#include "mma_traits.hpp"
#include "mma_transforms.hpp"
#include "mfma/mfma.hpp"
#include "wmma/wmma.hpp"
namespace ck_tile::core::arch::mma {
/*! @enum MmaAccumPolicy
 * @brief Accumulation order for Mma decomposition
 */
enum struct MmaAccumPolicy
{
    // Decomposition and accumulation in row-major block order (N blocks iterated innermost)
    ROW_MAJOR,
    // Decomposition and accumulation in col-major block order (M blocks iterated innermost)
    COL_MAJOR
};
/**
 * @class Mma
 * @brief Driver for the wave-tile Mma operation. Given a backend block-wise MmaOp implementation
 * (e.g., mfma or wmma), this class performs block-wise decomposition to matrix-multiply input
 * fragments of (A: FragM x FragK) x (B: FragK x FragN) and accumulates results into output fragment
 * (C: FragM x FragN).
 * @tparam ADataType Data type of input fragment A
 * @tparam BDataType Data type of input fragment B
 * @tparam CDataType Data type of input/output fragment C (accumulator)
 * @tparam FragM Mma fragment M dimension
 * @tparam FragN Mma fragment N dimension
 * @tparam FragK Mma fragment K dimension
 * @tparam AccumPolicy The block order of the accumulation registers (row major or col major block
 * order)
 * @tparam CompilerTarget The compiler target
 * @tparam MmaOp The backend wrapper class that will perform block-wise mma op (e.g., mfma or
 * wmma)
 * @tparam MmaTransforms The set of transforms to be applied to input/output fragments
 * @par This is an example of an Mma decomposition driver class that can be used in a wave-tile
 * context. Given a fragment size, we can decompose the fragment into smaller block-wise mma ops
 * that are natively supported by the hardware (e.g., mfma or wmma). The class also supports
 * applying transforms to the input/output fragments as needed (e.g., layout conversions, data type
 * conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the
 * output fragment. This is a powerful example of how to build a flexible and reusable mma driver
 * that can adapt to different hardware capabilities and requirements.
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t FragM,
          uint32_t FragN,
          uint32_t FragK,
          MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
          typename CompilerTarget =
              decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
                                               // get_compiler_target(),
          typename MmaOp =
              typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
                                                     // MmaDefaultSelector<ADataType,
                                          BDataType,
                                          CDataType,
                                          FragM,
                                          FragN,
                                          FragK,
                                          CompilerTarget>::SelectedOp,
          typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
              typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
struct WaveWiseMma
{
    using BlockWiseMmaOp       = MmaOp;
    using BlockWiseMmaOpTraits = MmaOpTraits<BlockWiseMmaOp>;
    // Block dimensions of the underlying hardware mma instruction
    constexpr static uint32_t BlockM = BlockWiseMmaOpTraits::BlockM;
    constexpr static uint32_t BlockN = BlockWiseMmaOpTraits::BlockN;
    constexpr static uint32_t BlockK = BlockWiseMmaOpTraits::BlockK;
    // Block counts for decomposition
    constexpr static uint32_t BlocksM = FragM / BlockM;
    constexpr static uint32_t BlocksN = FragN / BlockN;
    constexpr static uint32_t BlocksK = FragK / BlockK;
    constexpr static uint32_t BlocksC = BlocksM * BlocksN;
    // Vector types for packed registers in each block
    using AVecType = typename BlockWiseMmaOpTraits::AVecType;
    using BVecType = typename BlockWiseMmaOpTraits::BVecType;
    using CVecType = typename BlockWiseMmaOpTraits::CVecType;
    // Buffer types for fragments (block-indexed views over the fragment registers)
    using ABufferType = AVecType[BlocksM][BlocksK];
    using BBufferType = BVecType[BlocksN][BlocksK];
    using CBufferType = CVecType[BlocksM][BlocksN];
    // Transforms
    using ATransform = typename MmaTransforms::ATransform;
    using BTransform = typename MmaTransforms::BTransform;
    using CTransform = typename MmaTransforms::CTransform;
    using DTransform = typename MmaTransforms::DTransform;
    // Sanity checks (messages match the >= / divisibility conditions tested)
    static_assert(FragM >= BlockM, "FragM must be at least BlockM");
    static_assert(FragN >= BlockN, "FragN must be at least BlockN");
    static_assert(FragK >= BlockK, "FragK must be at least BlockK");
    static_assert(FragM % BlockM == 0u, "FragM must be a multiple of BlockM");
    static_assert(FragN % BlockN == 0u, "FragN must be a multiple of BlockN");
    static_assert(FragK % BlockK == 0u, "FragK must be a multiple of BlockK");

    private:
    /*! @brief Reinterpret a fragment as a block-indexed buffer (const overload).
     * @note Returns decltype(auto) so the reference produced by the reinterpret_cast is
     * preserved. A plain 'auto' return would decay the array type DstT to a pointer,
     * which breaks the final conversion back to the fragment vector type (the pointer's
     * size would fail the static_assert below).
     */
    template <typename DstT, typename SrcT>
    CK_TILE_DEVICE static decltype(auto) formatBuffer(SrcT const& inputBuffer)
    {
        // TODO: Implement formatting logic as needed.
        // This is intended to convert input fragments to the native vector types
        // required by the BlockWiseMma operation for iteration
        static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
        return reinterpret_cast<DstT const&>(inputBuffer);
    }
    /*! @brief Reinterpret a fragment as a block-indexed buffer (mutable overload). */
    template <typename DstT, typename SrcT>
    CK_TILE_DEVICE static decltype(auto) formatBuffer(SrcT& inputBuffer)
    {
        // TODO: Implement formatting logic as needed.
        // This is intended to convert input fragments to the native vector types
        // required by the BlockWiseMma operation for iteration
        static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
        return reinterpret_cast<DstT&>(inputBuffer);
    }
    /*! @brief Execute Mma in col-major accumulation order.
     * @tparam VecTA The input fragment A vector type
     * @tparam VecTB The input fragment B vector type
     * @tparam VecTC The input/output fragment C vector type
     */
    template <typename VecTA, typename VecTB, typename VecTC>
    CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum)
    {
        // We implement an example wave-tile pipeline here.
        // First, we apply the necessary transforms to the input fragments,
        // then we convert the result into buffers of native vector formats
        // that we can easily index. Native vector formats are necessary inputs
        // to the given MmaOp exec function.
        // NOTE: bind by reference ('auto&&') so the array buffer types do not decay.
        auto&& a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
        auto&& b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
        auto&& c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
        // "Col-major" accumulation over the M-dimension blocks first.
        // Pseudo code here, but we would basically iterate over the blocks in col-major order
        for(uint32_t bn = 0u; bn < BlocksN; ++bn)
        {
            for(uint32_t bm = 0u; bm < BlocksM; ++bm)
            {
                for(uint32_t bk = 0u; bk < BlocksK; ++bk)
                {
                    c_frag[bm][bn] =
                        BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
                }
            }
        }
        // Convert native vector results back to the output fragment format
        // and then return after we apply the final output transform.
        return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
    }
    /*! @brief Execute Mma in row-major accumulation order.
     * @tparam VecTA The input fragment A vector type
     * @tparam VecTB The input fragment B vector type
     * @tparam VecTC The input/output fragment C vector type
     */
    template <typename VecTA, typename VecTB, typename VecTC>
    CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum)
    {
        // We implement an example wave-tile pipeline here.
        // First, we apply the necessary transforms to the input fragments,
        // then we convert the result into buffers of native vector formats
        // that we can easily index. Native vector formats are necessary inputs
        // to the given MmaOp exec function.
        // NOTE: bind by reference ('auto&&') so the array buffer types do not decay.
        auto&& a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
        auto&& b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
        auto&& c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
        // "Row-major" accumulation over the N-dimension blocks first.
        // Pseudo code here, but we would basically iterate over the blocks in row-major order.
        // We also have to ensure that the incoming vector fragments are converted to native vector
        // types before passing to the BlockWiseMma exec function.
        for(uint32_t bm = 0u; bm < BlocksM; ++bm)
        {
            for(uint32_t bn = 0u; bn < BlocksN; ++bn)
            {
                for(uint32_t bk = 0u; bk < BlocksK; ++bk)
                {
                    c_frag[bm][bn] =
                        BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
                }
            }
        }
        // Convert native vector results back to the output fragment format
        // and then return after we apply the final output transform.
        return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
    }

    public:
    /*! @brief Forward to Mma operation with specified accumulation order.
     * @tparam VecTA The input fragment A vector type
     * @tparam VecTB The input fragment B vector type
     * @tparam VecTC The input/output fragment C vector type
     */
    template <typename VecTA, typename VecTB, typename VecTC>
    CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
    {
        if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
        {
            return exec_row_major(
                std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
        }
        else // if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
        {
            return exec_col_major(
                std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
        }
    }
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,63 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile::core::arch::mma {
/**
 * @class MmaDefaultSelector
 * @brief Implements a default mma selector strategy for the current target architecture.
 * This is simply intended as a default selection strategy for mma instruction operations.
 * Given the particular datatypes and Fragment dimensions, the selector will attempt to
 * select the instruction with the largest K dimension that is supported on the current target
 * architecture.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 * @tparam FragM Fragment M dimension
 * @tparam FragN Fragment N dimension
 * @tparam FragK Fragment K dimension
 * @tparam CompilerTarget The compiler target
 * @tparam Enable SFINAE enabler
 * @note Here we distinguish that Fragment MNK sizes from Block MNK sizes used in the actual MMA
 * operation. Fragment sizes correspond to the overall tile size being computed, while Block sizes
 * correspond to the size of the individual MMA instructions being used to compute the overall in
 * block-wise. The Fragment sizes must be multiples of the Block sizes and in general larger than or
 * equal to the Block sizes.
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t FragM,
          uint32_t FragN,
          uint32_t FragK,
          typename CompilerTarget,
          typename Enable = void>
// TODO c++20 requires
struct MmaDefaultSelector
{
    // By default, no selection is made, and we fall back to a pass-through unsupported
    // implementation. This is because we do not have any knowledge of the target architecture here.
    // (void CtrlFlags and a default-constructed amdgcn_target resolve to the unsupported base.)
    using SelectedOp =
        amdgcn_mma<ADataType, BDataType, CDataType, FragM, FragN, FragK, void, amdgcn_target<>>;
};
#if defined(__cpp_concepts) && __cpp_concepts >= 201907L
/**
 * @concept MmaSelectorI
 * @brief Expresses the required members for each MmaSelector class.
 */
template <typename MmaSelector>
concept MmaSelectorI = requires(MmaSelector op) {
    // Selectors should have a resulting SelectedOp type
    typename MmaSelector::SelectedOp;
};
#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L
} // namespace ck_tile::core::arch::mma
// Include the implementations
#include "wmma/wmma_selector.hpp"
#include "mfma/mfma_selector.hpp"

View File

@@ -0,0 +1,151 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "amdgcn_mma.hpp"
#include "mfma/mfma_traits.hpp"
#include "wmma/wmma_traits.hpp"
namespace ck_tile::core::arch::mma {
/**
 * @class is_mma_op_supported
 * @brief Trait to check if MmaOp is supported.
 * Primary template defaults to supported (true); the specialization below flips this
 * to false when the op's OpType is the Unsupported meta-tag.
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
// TODO: c++20 template <MmaOpI MmaOp, typename = void>
template <typename MmaOp, typename = void>
struct is_mma_op_supported : std::true_type
{
};
/**
 * @struct is_mma_op_supported
 * @brief The MmaOp is unsupported specialization
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
// TODO: c++20 template <MmaOpI MmaOp>
template <typename MmaOp>
// TODO: c++20 requires
struct is_mma_op_supported<MmaOp,
                           std::enable_if_t<std::is_same_v<typename MmaOp::OpType, Unsupported>>>
    : std::false_type
{
};
/**
 * @brief Convenience evaluation of is_mma_op_supported
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
// TODO: c++20 template <MmaOpI MmaOp>
template <typename MmaOp>
static constexpr bool is_mma_op_supported_v = is_mma_op_supported<MmaOp>::value;
/**
 * @class MmaOpParams
 * @brief Reflects the template parameters of a given MmaOp.
 * Primary template is declared only; specializations (below) match concrete amdgcn_mma
 * instantiations and expose their template arguments as members.
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
// TODO: c++20 template <MmaOpI MmaOp>
template <typename MmaOp>
struct MmaOpParams;
#if defined(__cpp_concepts) && __cpp_concepts >= 201907L
/**
 * @concept MmaOpParamsI
 * @brief Expresses the required members for each MmaOp
 */
template <typename MmaOpParams>
concept MmaOpParamsI = requires(MmaOpParams op) {
    // Capture template parameters
    typename MmaOpParams::ADataType;
    typename MmaOpParams::BDataType;
    typename MmaOpParams::CDataType;
    typename MmaOpParams::CtrlFlags;
    { MmaOpParams::BlockM } -> std::convertible_to<unsigned int>;
    { MmaOpParams::BlockN } -> std::convertible_to<unsigned int>;
    { MmaOpParams::BlockK } -> std::convertible_to<unsigned int>;
    { MmaOpParams::GfxTargetId } -> std::convertible_to<amdgcn_target_arch_id>;
};
#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L
/**
 * @struct MmaOpParams
 * @brief Reflects the template parameters of a given MmaOp
 * @tparam ADataType_ Data type of matrix A
 * @tparam BDataType_ Data type of matrix B
 * @tparam CDataType_ Data type of the accumulator
 * @tparam BlockM_ Size of the M dimension
 * @tparam BlockN_ Size of the N dimension
 * @tparam BlockK_ Size of the K dimension
 * @tparam CtrlFlags_ Control flags for the MMA operation
 * @tparam CompilerTarget_ The compiler target
 */
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          uint32_t BlockM_,
          uint32_t BlockN_,
          uint32_t BlockK_,
          typename CtrlFlags_,
          typename CompilerTarget_>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget_>
struct MmaOpParams<amdgcn_mma<ADataType_,
                              BDataType_,
                              CDataType_,
                              BlockM_,
                              BlockN_,
                              BlockK_,
                              CtrlFlags_,
                              CompilerTarget_>>
{
    // Capture incoming template parameters
    using ADataType                 = ADataType_;
    using BDataType                 = BDataType_;
    using CDataType                 = CDataType_;
    static constexpr uint32_t BlockM = BlockM_;
    static constexpr uint32_t BlockN = BlockN_;
    static constexpr uint32_t BlockK = BlockK_;
    using CtrlFlags                 = CtrlFlags_;
    using CompilerTarget            = CompilerTarget_;
    // TODO: c++20 static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_;
};
/**
 * @class MmaOpTraits
 * @brief Reflects the template parameters and static members of a given MmaOp.
 * Inherits the reflected template parameters from MmaOpParams and additionally
 * re-exports the op's vector types, layout constants, and support/kind flags.
 * @tparam MmaOp The matrix multiply-accumulate operation
 */
template <typename MmaOp>
// TODO: c++20 template <MmaOpI MmaOp>
// TODO: c++20 requires MmaOpParamsI<MmaOpParams<MmaOp>>
struct MmaOpTraits : public MmaOpParams<MmaOp>
{
    // Capture internal MmaOp static members
    using OpType   = typename MmaOp::OpType;
    using AVecType = typename MmaOp::AVecType;
    using BVecType = typename MmaOp::BVecType;
    using CVecType = typename MmaOp::CVecType;
    // Capture layout parameters
    static constexpr index_t kAMBlock    = MmaOp::kAMBlock;
    static constexpr index_t kBNBlock    = MmaOp::kBNBlock;
    static constexpr index_t kAMLane     = MmaOp::kAMLane;
    static constexpr index_t kBNLane     = MmaOp::kBNLane;
    static constexpr index_t kABKLane    = MmaOp::kABKLane;
    static constexpr index_t kABKPerLane = MmaOp::kABKPerLane;
    static constexpr index_t kCMLane     = MmaOp::kCMLane;
    static constexpr index_t kCNLane     = MmaOp::kCNLane;
    static constexpr index_t kCM0PerLane = MmaOp::kCM0PerLane;
    static constexpr index_t kCM1PerLane = MmaOp::kCM1PerLane;
    // Additional traits to identify the type of MmaOp at compile time
    constexpr static bool IsMfma      = is_mma_op_mfma_v<MmaOp>;
    constexpr static bool IsWmma      = is_mma_op_wmma_v<MmaOp>;
    constexpr static bool IsSupported = is_mma_op_supported_v<MmaOp>;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,48 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile::core::arch::mma {
/**
 * @struct PassThroughTransform
 * @brief Identity transform: hands its argument back untouched.
 */
struct PassThroughTransform
{
    // Returns the input with its original value category preserved
    // (equivalent to std::forward on the deduced reference).
    template <typename Vec>
    CK_TILE_DEVICE static decltype(auto) exec(Vec&& vec)
    {
        return static_cast<Vec&&>(vec);
    }
};
/**
 * @class MmaTransformsDefaultSelector
 * @brief Default selector for MmaTransforms based on MmaOp and CompilerTarget.
 * The primary template is intentionally left undefined; per-target
 * specializations (selected via the Enable SFINAE parameter) must provide a
 * nested SelectedTransforms type.
 * @tparam MmaOp The Mma operation type
 * @tparam CompilerTarget The compiler target
 * @tparam Enable SFINAE parameter for specialization
 */
template <typename MmaOp, typename CompilerTarget, typename Enable = void>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget, typename Enable = void>
struct MmaTransformsDefaultSelector;
#if defined(__cpp_concepts) && __cpp_concepts >= 201907L
/**
 * @concept MmaTransformsI
 * @brief Expresses the interface of required members for each MmaTransforms type.
 */
template <typename MmaTransforms>
concept MmaTransformsI = requires(MmaTransforms transforms) {
    // Transforms should define ATransform, BTransform, CTransform, and DTransform types
    // (one transform stage per operand: A/B inputs, C accumulator, D output)
    typename MmaTransforms::ATransform;
    typename MmaTransforms::BTransform;
    typename MmaTransforms::CTransform;
    typename MmaTransforms::DTransform;
};
#endif // defined(__cpp_concepts) && __cpp_concepts >= 201907L
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,34 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile::core::arch::mma {
/**
 * @enum WmmaCtrlFlags
 * @brief Common wmma control flags for gfx11 and gfx12.
 * Note: LOW/UNSIGNED and HIGH/SIGNED deliberately share underlying bool values;
 * which meaning applies depends on the context in which the flag is consumed.
 */
enum struct WmmaCtrlFlags : bool
{
    // Only has an effect on gfx11 when the accumulator is 16-bit.
    // Determines which half of the 32-bit accum register to use:
    // LOW  = bits [15:0]
    // HIGH = bits [31:16]
    LOW = false,
    HIGH = true,
    // Only has an effect on gfx11 / 12 when the input is 8-bit int.
    // Signage indicator of inputs / accum.
    UNSIGNED = false,
    SIGNED = true
};
} // namespace ck_tile::core::arch::mma
// Include the architecture-specific WMMA implementations and traits
#include "wmma_gfx11.hpp"
#include "wmma_gfx12.hpp"
#include "wmma_selector.hpp"
#include "wmma_traits.hpp"
#include "wmma_transforms.hpp"

View File

@@ -0,0 +1,109 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "wmma_traits.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
namespace ck_tile::core::arch::mma {
// TODO: Specifically for gfx11 wmma, we need to deal with quirks such as:
// - Duplicating A and B inputs
// - Handling C / D is always in b32, even for f16 accumulation.
// NOTE: Two suggestions:
// 1) We could do it here in the wrappers by accepting packed inputs, then swizzling them to
// duplicate the inputs as needed before calling the actual built-in. This may introduce
// some instruction overhead and violate single responsibility clauses, but keeps the logic
// contained within the backend wrapper.
// 2) We could do it at a higher level, e.g. in the Mma interface (workflow) by introducing
// pre-mma, mma and post-mma steps. The pre-mma step could handle input duplication transform
// post-mma could implement D-shuffle transform. This may be cleaner and more flexible than
// trying to handle everything in the backend wrappers.
//
// This current example assumes duplication has already been done, and that C data shuffles have
// already been completed. (e.g. option 2 above). These expect duplicated inputs and pre-shuffled
// data in C.
// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed.
// This is because some built-ins are only available on certain target ids.
// We can also do things, such as adding padding specializations for when we need to use
// smaller values of K that aren't directly supported by the built-ins.
// For flexibility, it is recommended that for each backend wrapper it supports at least
// one packed register for each input to be able to process smaller K values by padding.
/**
 * @class DefaultWmmaCtrlFlags
 * @brief Generates default WMMA control flags based on data types.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 */
template <typename ADataType, typename BDataType, typename CDataType>
struct DefaultWmmaCtrlFlags
{
    // Signage flags derived from the signedness of each operand type.
    // Currently only consumed for integer inputs / accum on gfx11 / gfx12.
    constexpr static WmmaCtrlFlags InputSignA =
        std::is_signed_v<ADataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
    constexpr static WmmaCtrlFlags InputSignB =
        std::is_signed_v<BDataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
    constexpr static WmmaCtrlFlags AccumSign =
        std::is_signed_v<CDataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
    // Default accumulator destination half.
    // Only used if accumulation size is 16-bit in gfx11.
    constexpr static WmmaCtrlFlags AccumBits = WmmaCtrlFlags::LOW;
};
/**
 * @struct amdgcn_mma
 * @brief Specialization of amdgcn_mma for the fp16_t x fp16_t -> fp32_t 16x16x16 MMA
 * operation on the GFX11 architecture family (wave32 WMMA builtin).
 * @tparam CtrlFlags Control flags for the WMMA operation
 * @tparam CompilerTarget Current compiler target
 */
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t,
                  fp16_t,
                  fp32_t,
                  16u,
                  16u,
                  16u,
                  CtrlFlags,
                  CompilerTarget,
                  enable_if_target_family_gfx11_t<CompilerTarget>>
{
    // Tag identifying this backend as a WMMA operation
    using OpType = WmmaOp;
    // Register types. gfx11 expects pre-duplicated A/B inputs (16 fp16 elements
    // rather than 8 — see the file-level NOTE on duplication) and a b32 accum.
    using AVecType = ext_vector_t<fp16_t, 16>;
    using BVecType = ext_vector_t<fp16_t, 16>;
    using CVecType = ext_vector_t<fp32_t, 8>;
    // Layout constants describing the per-wave register mapping of the tile.
    // NOTE(review): values mirror the gfx12 specialization — confirm against the ISA guide.
    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;
    static constexpr index_t kAMLane = 16;
    static constexpr index_t kBNLane = 16;
    static constexpr index_t kABKLane = 8;
    static constexpr index_t kABKPerLane = 8;
    static constexpr index_t kCMLane = 2;
    static constexpr index_t kCNLane = 2;
    static constexpr index_t kCM0PerLane = 4;
    static constexpr index_t kCM1PerLane = 1;
    // Issues the gfx11 wave32 f32.16x16x16.f16 WMMA builtin on the given fragments.
    CK_TILE_DEVICE static auto
    exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
    {
        return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aVec, bVec, cVec)};
    }
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,69 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "wmma_traits.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
namespace ck_tile::core::arch::mma {
// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed.
// This is because some built-ins are only available on certain target ids.
// We can also do things, such as adding padding specializations for when we need to use
// smaller values of K that aren't directly supported by the built-ins.
// For flexibility, it is recommended that for each backend wrapper it supports at least
// one packed register for each input to be able to process smaller K values by padding.
/**
 * @struct amdgcn_mma
 * @brief Specialization of amdgcn_mma for the fp16_t x fp16_t -> fp32_t 16x16x16 MMA
 * operation on the GFX12 architecture family (wave32 WMMA builtin).
 * @tparam CtrlFlags Control flags for the WMMA operation
 * @tparam CompilerTarget Current compiler target
 */
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t,
                  fp16_t,
                  fp32_t,
                  16u,
                  16u,
                  16u,
                  CtrlFlags,
                  CompilerTarget,
                  enable_if_target_family_gfx12_t<CompilerTarget>>
{
    // Tag identifying this backend as a WMMA operation
    using OpType = WmmaOp;
    // Register types. Unlike gfx11, gfx12 consumes the inputs directly
    // (8 fp16 elements per fragment, no duplication required).
    using AVecType = ext_vector_t<fp16_t, 8>;
    using BVecType = ext_vector_t<fp16_t, 8>;
    using CVecType = ext_vector_t<fp32_t, 8>;
    // Layout constants describing the per-wave register mapping of the tile.
    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;
    static constexpr index_t kAMLane = 16;
    static constexpr index_t kBNLane = 16;
    static constexpr index_t kABKLane = 8;
    static constexpr index_t kABKPerLane = 8;
    static constexpr index_t kCMLane = 2;
    static constexpr index_t kCNLane = 2;
    static constexpr index_t kCM0PerLane = 4;
    static constexpr index_t kCM1PerLane = 1;
    // Issues the gfx12 wave32 f32.16x16x16.f16 WMMA builtin on the given fragments.
    CK_TILE_DEVICE static auto
    exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
    {
        return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(aVec, bVec, cVec)};
    }
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,161 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
namespace ck_tile::core::arch::mma {
/**
 * @class WmmaDefaultSelector
 * @brief Implements a default WMMA selector strategy for gfx11/12 target architectures.
 * This implements the K dimension search strategy to find the largest supported WMMA
 * instruction for the given M/N block sizes and datatypes: BlockKTest is halved
 * recursively until a supported backend is found or the BlockK == 1 base case is hit.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 * @tparam BlockM Size of the M dimension
 * @tparam BlockN Size of the N dimension
 * @tparam BlockKTest Candidate size of the K dimension tested at this recursion step
 * @tparam CompilerTarget The compiler target
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t BlockM,
          uint32_t BlockN,
          uint32_t BlockKTest,
          typename CompilerTarget>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires(is_rdna_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest))
struct WmmaDefaultSelector
{
    private:
    // By default, let's assume no special flags for WMMA
    using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
    // Define our candidate WMMA implementation for the current parameters
    using CandidateOp = amdgcn_mma<ADataType,
                                   BDataType,
                                   CDataType,
                                   BlockM,
                                   BlockN,
                                   BlockKTest,
                                   CtrlFlags,
                                   CompilerTarget>;
    using CandidateTraits = MmaOpTraits<CandidateOp>;
    public:
    // If the candidate is supported (e.g., a backend implementation exists), then select it.
    // Otherwise, test another smaller BlockK. If no implementation exists at any K, the
    // recursion terminates at the BlockK=1u specialization, which falls back to the
    // unsupported pass-through implementation.
    using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
                                          CandidateOp,
                                          typename WmmaDefaultSelector<ADataType,
                                                                       BDataType,
                                                                       CDataType,
                                                                       BlockM,
                                                                       BlockN,
                                                                       BlockKTest / 2u,
                                                                       CompilerTarget>::SelectedOp>;
};
/**
 * @struct WmmaDefaultSelector
 * @brief Implements a default WMMA selector strategy for gfx11/12 target architectures.
 * This implements the K dimension == 1 case, which is the base case terminating the
 * recursive K dimension search. If no supported instruction was found at any larger K,
 * this falls back to an unsupported pass-through implementation.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 * @tparam BlockM Size of the M dimension
 * @tparam BlockN Size of the N dimension
 * @tparam CompilerTarget The compiler target
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t BlockM,
          uint32_t BlockN,
          typename CompilerTarget>
// TODO: c++20 amdgcn_target_arch_id GfxTargetId>
struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CompilerTarget>
{
    // By default, let's assume no special flags for WMMA
    using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
    // Default unsupported pass-through if no instruction is found
    // (the primary amdgcn_mma template with no backend specialization)
    using SelectedOp =
        amdgcn_mma<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CtrlFlags, CompilerTarget>;
};
/**
 * @struct MmaDefaultSelector
 * @brief Implements the rdna default MMA selector strategy for wave-wise MMA decomposition.
 * This implements the M/N block size search strategy to find the largest supported WMMA
 * instruction for the given datatypes.
 * If no supported instruction is found, falls back to an unsupported pass-through implementation.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 * @tparam FragM Size of the M dimension of the fragment to decompose
 * @tparam FragN Size of the N dimension of the fragment to decompose
 * @tparam FragK Size of the K dimension of the fragment to decompose
 * @tparam CompilerTarget The compiler target
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t FragM,
          uint32_t FragN,
          uint32_t FragK,
          typename CompilerTarget>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires
struct MmaDefaultSelector<ADataType,
                          BDataType,
                          CDataType,
                          FragM,
                          FragN,
                          FragK,
                          CompilerTarget,
                          enable_if_target_arch_rdna_t<CompilerTarget>>
{
    private:
    // Provide the default depth-K search strategy for each class of common WMMA shapes.
    // Start searching from the largest K dimension WMMA shape (K=128) down to the smallest.
    using CandidateOp16x16 = typename WmmaDefaultSelector<ADataType,
                                                          BDataType,
                                                          CDataType,
                                                          16u,
                                                          16u,
                                                          128u,
                                                          CompilerTarget>::SelectedOp;
    // Default operation triggers pass-through
    using DefaultOp =
        typename WmmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
            SelectedOp;
    // Traits for each candidate
    using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
    // Check if each candidate is supported for the given fragment sizes.
    // For this case, we require the fragment sizes to be multiples of the WMMA shape
    // so the fragment decomposes into whole WMMA tiles.
    static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
                                             (FragM % CandidateTraits16x16::BlockM == 0u) &&
                                             (FragN % CandidateTraits16x16::BlockN == 0u) &&
                                             (FragK % CandidateTraits16x16::BlockK == 0u);
    public:
    // Select the largest supported WMMA operation for the given fragment shape
    using SelectedOp = std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,44 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile::core::arch::mma {
/**
 * @struct WmmaOp
 * @brief Meta-tag identifying an operation as a WMMA operation. Each WMMA backend
 * advertises this through its nested OpType alias.
 */
struct WmmaOp;
/**
 * @class is_mma_op_wmma
 * @brief Trait that reports whether MmaOp::OpType names WmmaOp.
 * The primary template handles types with no OpType alias and yields false.
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
template <typename MmaOp, typename = void>
struct is_mma_op_wmma : std::false_type
{
};
/**
 * @struct is_mma_op_wmma
 * @brief Specialization chosen whenever MmaOp exposes an OpType alias; the result is
 * then simply whether that alias is WmmaOp.
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
template <typename MmaOp>
// TODO: c++20 requires
struct is_mma_op_wmma<MmaOp, std::void_t<typename MmaOp::OpType>>
    : std::is_same<typename MmaOp::OpType, WmmaOp>
{
};
/**
 * @brief Convenience evaluator for is_mma_op_wmma trait
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
template <typename MmaOp>
static constexpr bool is_mma_op_wmma_v = is_mma_op_wmma<MmaOp>::value;
} // namespace ck_tile::core::arch::mma

View File

@@ -0,0 +1,112 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
namespace ck_tile::core::arch::mma {
/**
 * @struct DuplicateTransform
 * @brief Transform to duplicate low register elements to high register elements.
 * NOTE: currently a pass-through placeholder — the duplication itself is not yet
 * implemented (see TODO inside exec).
 */
struct DuplicateTransform
{
    template <typename VecType>
    CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
    {
        // TODO: Implement duplication logic to broadcast low
        // register elements to high elements [0 - (N/2 -1)] -> [N/2 - (N-1)]
        return std::forward<VecType>(v);
    }
};
/**
 * @struct PadTransform
 * @brief Transform to pad data from original type to b32 type.
 * NOTE: currently a pass-through placeholder — the padding itself is not yet
 * implemented (see TODO inside exec).
 */
struct PadTransform
{
    template <typename VecType>
    CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
    {
        // TODO: Implement b32 padding logic.
        // E.g., for fp16, pad each 16-bit element with 16 bits of 0 to make 32-bit elements
        return std::forward<VecType>(v);
    }
};
/**
 * @struct UnpadTransform
 * @brief Transform to unpad data from b32 type back to the original type.
 * NOTE: currently a pass-through placeholder — the unpadding itself is not yet
 * implemented (see TODO inside exec).
 */
struct UnpadTransform
{
    template <typename VecType>
    CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
    {
        // TODO: Implement b32 logic to unpad 32 to original data type.
        return std::forward<VecType>(v);
    }
};
/**
 * @struct MmaDefaultTransformsGfx11
 * @brief Default MMA transforms for GFX11 architecture.
 * gfx11 WMMA expects duplicated A/B inputs and a b32 accumulator, so the input
 * transforms duplicate while the C/D transforms pad and unpad respectively.
 */
struct MmaDefaultTransformsGfx11
{
    using ATransform = DuplicateTransform;
    using BTransform = DuplicateTransform;
    using CTransform = PadTransform;
    using DTransform = UnpadTransform;
};
/**
 * @struct MmaDefaultTransformsGfx12
 * @brief Default MMA transforms for GFX12 architecture.
 * gfx12 WMMA consumes its operands directly, so every stage is a no-op.
 */
struct MmaDefaultTransformsGfx12
{
    using ATransform = PassThroughTransform;
    using BTransform = PassThroughTransform;
    using CTransform = PassThroughTransform;
    using DTransform = PassThroughTransform;
};
/**
 * @struct MmaTransformsDefaultSelector
 * @brief Implements the default MMA transforms selection for gfx11 targets.
 * @tparam MmaOp Mma operation
 * @tparam CompilerTarget The compiler target
 */
template <typename MmaOp, typename CompilerTarget>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
// TODO: c++20 requires
struct MmaTransformsDefaultSelector<MmaOp,
                                    CompilerTarget,
                                    enable_if_target_family_gfx11_t<CompilerTarget>>
{
    // Duplicate inputs / pad-unpad accumulator, per the gfx11 WMMA requirements
    using SelectedTransforms = MmaDefaultTransformsGfx11;
};
/**
 * @struct MmaTransformsDefaultSelector
 * @brief Implements the default MMA transforms selection for gfx12 targets.
 * @tparam MmaOp Mma operation
 * @tparam CompilerTarget The compiler target
 */
template <typename MmaOp, typename CompilerTarget>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
// TODO: c++20 requires
struct MmaTransformsDefaultSelector<MmaOp,
                                    CompilerTarget,
                                    enable_if_target_family_gfx12_t<CompilerTarget>>
{
    // gfx12 needs no input/accumulator fix-ups — all stages pass through
    using SelectedTransforms = MmaDefaultTransformsGfx12;
};
} // namespace ck_tile::core::arch::mma

View File

@@ -284,3 +284,242 @@
#ifndef CK_TILE_ENC_SUPPORT_Y_TO_R
#define CK_TILE_ENC_SUPPORT_Y_TO_R 0
#endif
// Mark unsupported features with a deprecation warning in debug builds.
// In release builds (NDEBUG defined) the macro expands to nothing, so no
// diagnostics are emitted; otherwise it attaches a deprecation attribute
// that surfaces MSG at the use site.
#if defined(NDEBUG)
#define CK_TILE_UNSUPPORTED_IMPL(MSG)
#else
#define CK_TILE_UNSUPPORTED_IMPL(MSG) __attribute__((deprecated(MSG)))
#endif
namespace ck_tile::core {
/**
 * @struct amdgcn_compiler_target_state
 * @brief Defines compiler states for supported AMDGCN devices.
 * @var CK_TILE_HOST_COMPILE Indicates if the compilation is for the host.
 * @var CK_TILE_DEVICE_COMPILE Indicates if the compilation is for AMDGCN device.
 * @var CK_TILE_ARCH_GFX908 Indicates if the compiler target architecture is GFX908.
 * @var CK_TILE_ARCH_GFX90A Indicates if the compiler target architecture is GFX90A.
 * @var CK_TILE_ARCH_GFX942 Indicates if the compiler target architecture is GFX942.
 * @var CK_TILE_ARCH_GFX950 Indicates if the compiler target architecture is GFX950.
 * @var CK_TILE_ARCH_GFX1030 Indicates if the compiler target architecture is GFX1030.
 * @var CK_TILE_ARCH_GFX1031 Indicates if the compiler target architecture is GFX1031.
 * @var CK_TILE_ARCH_GFX1032 Indicates if the compiler target architecture is GFX1032.
 * @var CK_TILE_ARCH_GFX1034 Indicates if the compiler target architecture is GFX1034.
 * @var CK_TILE_ARCH_GFX1035 Indicates if the compiler target architecture is GFX1035.
 * @var CK_TILE_ARCH_GFX1036 Indicates if the compiler target architecture is GFX1036.
 * @var CK_TILE_ARCH_GFX10_3_GENERIC Indicates if the compiler target architecture is GFX10.3
 * generic.
 * @var CK_TILE_ARCH_GFX1100 Indicates if the compiler target architecture is GFX1100.
 * @var CK_TILE_ARCH_GFX1101 Indicates if the compiler target architecture is GFX1101.
 * @var CK_TILE_ARCH_GFX1102 Indicates if the compiler target architecture is GFX1102.
 * @var CK_TILE_ARCH_GFX1103 Indicates if the compiler target architecture is GFX1103.
 * @var CK_TILE_ARCH_GFX1150 Indicates if the compiler target architecture is GFX1150.
 * @var CK_TILE_ARCH_GFX1151 Indicates if the compiler target architecture is GFX1151.
 * @var CK_TILE_ARCH_GFX1152 Indicates if the compiler target architecture is GFX1152.
 * @var CK_TILE_ARCH_GFX11_GENERIC Indicates if the compiler target architecture is GFX11 generic.
 * @var CK_TILE_ARCH_GFX1200 Indicates if the compiler target architecture is GFX1200.
 * @var CK_TILE_ARCH_GFX1201 Indicates if the compiler target architecture is GFX1201.
 * @var CK_TILE_ARCH_GFX12_GENERIC Indicates if the compiler target architecture is GFX12 generic.
 */
struct amdgcn_compiler_target_state
{
// Determine if we are compiling for device or host
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
    static constexpr bool CK_TILE_DEVICE_COMPILE = true;
    static constexpr bool CK_TILE_HOST_COMPILE = false;
#else
    static constexpr bool CK_TILE_DEVICE_COMPILE = false;
    static constexpr bool CK_TILE_HOST_COMPILE = true;
#endif // __HIP_DEVICE_COMPILE__ && __HIP_DEVICE_COMPILE__
// GFX9
#if defined(__gfx908__)
    static constexpr bool CK_TILE_ARCH_GFX908 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX908 = false;
#endif // __gfx908__
#if defined(__gfx90a__)
    static constexpr bool CK_TILE_ARCH_GFX90A = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX90A = false;
#endif // __gfx90a__
#if defined(__gfx942__)
    static constexpr bool CK_TILE_ARCH_GFX942 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX942 = false;
#endif // __gfx942__
#if defined(__gfx950__)
    static constexpr bool CK_TILE_ARCH_GFX950 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX950 = false;
#endif // __gfx950__
// GFX10
#if defined(__gfx1030__)
    static constexpr bool CK_TILE_ARCH_GFX1030 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1030 = false;
#endif // __gfx1030__
#if defined(__gfx1031__)
    static constexpr bool CK_TILE_ARCH_GFX1031 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1031 = false;
#endif // __gfx1031__
#if defined(__gfx1032__)
    static constexpr bool CK_TILE_ARCH_GFX1032 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1032 = false;
#endif // __gfx1032__
#if defined(__gfx1034__)
    static constexpr bool CK_TILE_ARCH_GFX1034 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1034 = false;
#endif // __gfx1034__
#if defined(__gfx1035__)
    static constexpr bool CK_TILE_ARCH_GFX1035 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1035 = false;
#endif // __gfx1035__
#if defined(__gfx1036__)
    static constexpr bool CK_TILE_ARCH_GFX1036 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1036 = false;
#endif // __gfx1036__
#if defined(__gfx10_3_generic__)
    static constexpr bool CK_TILE_ARCH_GFX10_3_GENERIC = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX10_3_GENERIC = false;
#endif // __gfx10_3_generic__
// GFX11
#if defined(__gfx1100__)
    static constexpr bool CK_TILE_ARCH_GFX1100 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1100 = false;
#endif // __gfx1100__
#if defined(__gfx1101__)
    static constexpr bool CK_TILE_ARCH_GFX1101 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1101 = false;
#endif // __gfx1101__
#if defined(__gfx1102__)
    static constexpr bool CK_TILE_ARCH_GFX1102 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1102 = false;
#endif // __gfx1102__
#if defined(__gfx1103__)
    static constexpr bool CK_TILE_ARCH_GFX1103 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1103 = false;
#endif // __gfx1103__
#if defined(__gfx1150__)
    static constexpr bool CK_TILE_ARCH_GFX1150 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1150 = false;
#endif // __gfx1150__
#if defined(__gfx1151__)
    static constexpr bool CK_TILE_ARCH_GFX1151 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1151 = false;
#endif // __gfx1151__
#if defined(__gfx1152__)
    static constexpr bool CK_TILE_ARCH_GFX1152 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1152 = false;
#endif // __gfx1152__
#if defined(__gfx11_generic__)
    static constexpr bool CK_TILE_ARCH_GFX11_GENERIC = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX11_GENERIC = false;
#endif // __gfx11_generic__
// GFX12
#if defined(__gfx1200__)
    static constexpr bool CK_TILE_ARCH_GFX1200 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1200 = false;
#endif // __gfx1200__
#if defined(__gfx1201__)
    static constexpr bool CK_TILE_ARCH_GFX1201 = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX1201 = false;
#endif // __gfx1201__
#if defined(__gfx12_generic__)
    static constexpr bool CK_TILE_ARCH_GFX12_GENERIC = true;
#else
    static constexpr bool CK_TILE_ARCH_GFX12_GENERIC = false;
#endif // __gfx12_generic__
};
/**
 * @brief Helper to count the number of times an item is contained within a list of values
 * @tparam T The type of the search value
 * @tparam Ts The types of the search list values
 * @param search The value to search for
 * @param searchList The list of values to search in
 * @return The number of entries in searchList that compare equal to search
 */
template <typename T, typename... Ts>
// TODO: c++20 concept requires((std::is_convertible<Ts, T>::value && ...) && (sizeof...(Ts) >=
// 1))
CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... searchList)
{
    static_assert((std::is_convertible<Ts, T>::value && ...),
                  "All search list values must be convertible to the search value type");
    static_assert(sizeof...(Ts) >= 1, "At least one value must be provided to search in");
    // Fold expression: each matching element contributes 1 to the sum
    return (static_cast<uint32_t>(search == static_cast<T>(searchList)) + ...);
}
// Compile-time list of every per-architecture flag in amdgcn_compiler_target_state;
// consumed by the sanity checks below to count how many targets are active at once.
// (Comments cannot live inside the continuation lines, hence this note up front.)
#define CK_TILE_COMPILER_TARGETS_LIST                          \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX908,         \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX90A,         \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942,         \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950,         \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1034,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1035,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1036,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX10_3_GENERIC,\
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1100,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1101,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1102,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1103,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1150,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1151,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1152,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX11_GENERIC,  \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1200,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1201,        \
    amdgcn_compiler_target_state::CK_TILE_ARCH_GFX12_GENERIC
// Sanity check: make sure exactly one target architecture is defined during device compile
static_assert(!amdgcn_compiler_target_state::CK_TILE_DEVICE_COMPILE ||
                  count_values_of(true, CK_TILE_COMPILER_TARGETS_LIST) == 1u,
              "Only one target architecture can be defined during device compile");
// Sanity check: make sure no device target architecture is defined during host compile
static_assert(!amdgcn_compiler_target_state::CK_TILE_HOST_COMPILE ||
                  count_values_of(true, CK_TILE_COMPILER_TARGETS_LIST) == 0u,
              "No device target architecture can be defined during host compile");
} // namespace ck_tile::core

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
@@ -14,6 +14,10 @@ struct ignore_t
constexpr void operator=(T&&) const noexcept
{
}
// No-op function-call operator: accepts any arguments and discards them,
// letting ignore_t stand in wherever a callable is expected.
template <typename... T>
constexpr void operator()(T&&...) const noexcept
{
}
};
} // namespace detail

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -128,6 +128,25 @@ struct is_any_of<CompareTo, FirstType, Rest...>
{
};
/**
 * @brief Tests whether a value equals at least one entry of a list of values.
 * @tparam T The type of the probed value
 * @tparam Ts The types of the list entries
 * @param search The value to look for
 * @param searchList The candidate values compared against search
 * @return true when search compares equal to any entry of searchList, false otherwise
 */
template <typename T, typename... Ts>
// TODO: c++20 requires((std::is_convertible<Ts, T>::value && ...) && (sizeof...(Ts) >= 1))
CK_TILE_HOST_DEVICE static constexpr bool is_any_value_of(T search, Ts... searchList)
{
    static_assert(sizeof...(Ts) >= 1, "searchList must contain at least one value");
    static_assert((std::is_convertible<Ts, T>::value && ...),
                  "All searchList values must be convertible to the type of search");
    // Left fold over || — short-circuits on the first match
    return (... || (search == static_cast<T>(searchList)));
}
// Helper to check if a type is a specialization of a given template
template <typename Test, template <typename...> class RefTemplate>
struct is_specialization_of : std::false_type