[CK_TILE] CK_TILE GEMM WMMA Support for GFX11/GFX12 (#2466)

* WMMA GEMM F16 Implementation

Signed-off-by: root <tianyuwu@amd.com>

* Self-review

Signed-off-by: root <tianyuwu@amd.com>

* ASIC check minor tweak

Signed-off-by: root <tianyuwu@amd.com>

* add missing include file

* Set GPU_TARGETS to gfx11/12 generic

Signed-off-by: root <tianyuwu@amd.com>

* INT8 GFX12

Signed-off-by: root <tianyuwu@amd.com>

* add int8x16 branch

* Fix CI script

Signed-off-by: root <tianyuwu@amd.com>

* Fix typo

Signed-off-by: root <tianyuwu@amd.com>

* Add CK_Tile WMMA example

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* Fix CI

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* fix clang format

* Set M/N_Warp Back to Constant

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* Use GemmConfigComputeV3 by default

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Remove CK_Tile wmma gemm examples from the CI list

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Add atomic add fallback method for gfx11

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix typo

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Omit copyright year

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Support non-square cases

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix CI

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Add get_device_ip()

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Revert "Add atomic add fallback method for gfx11"

This reverts commit 4f664969c01b37976c8518c19833d9f1574cd746.

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* Revert "Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12"

This reverts commit 949129a3858a825b2a2c4d3ec01663df18a165a5.

* Revise method name and typos

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* clang-format

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Try fix CI

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Revert "Try fix CI"

This reverts commit 084c683227e64ab6a8137db00c8165fb05bdc902.

* clang-format

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix typo caused by merge

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* Fix typo caused by merging

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

---------

Signed-off-by: root <tianyuwu@amd.com>
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>

[ROCm/composable_kernel commit: 68134b60e4]
This commit is contained in:
Tianyuan Wu
2025-08-16 07:22:27 +08:00
committed by GitHub
parent 42d775e488
commit ec7ee5b7b7
54 changed files with 1388 additions and 403 deletions

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111
#define CK_TILE_VMCNT(cnt) \
@@ -59,7 +60,7 @@ enum struct memory_operation_enum : std::uint16_t
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
{
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
#if defined(__GFX9__) || (!defined(__HIP_DEVICE_COMPILE__) && !defined(CK_TILE_WAVE32_ENABLED))
return 64;
#else
return 32;
@@ -230,4 +231,20 @@ CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_
}
}
// Architecture tags
struct gfx11_t
{
};
struct gfx12_t
{
};
CK_TILE_DEVICE static constexpr auto get_device_arch()
{
#if defined(__gfx11__)
return gfx11_t{};
#else // if defined(__gfx12__)
return gfx12_t{};
#endif
}
} // namespace ck_tile

View File

@@ -6,6 +6,10 @@
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
namespace ck_tile {
template <typename T, typename ComputeType>
@@ -32,6 +36,14 @@ CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
return rtn;
}
CK_TILE_HOST_DEVICE fp16x2_t add_f16x2_t(const fp16x2_t& a, const fp16x2_t& b)
{
fp16x2_t rtn;
rtn[0] = add<fp16_t, float>(a[0], b[0]);
rtn[1] = add<fp16_t, float>(a[1], b[1]);
return rtn;
}
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
fp8x4_t rtn;
@@ -304,6 +316,44 @@ CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
} while(cur_v.u64 != old_v);
}
//
// Atomic add for fp16x2_t
//
template <>
CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, fp16x2_t const& x)
{
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
__builtin_amdgcn_global_atomic_fadd_v2f16(c_style_pointer_cast<fp16x2_t*>(p_dst), x);
#else
union U32F162_ADDR
{
uint32_t* u32_a;
fp16x2_t* f162_a;
};
union U32F162
{
uint32_t u32;
fp16x2_t f162;
};
U32F162_ADDR dword_addr;
U32F162 cur_v;
U32F162 new_;
uint32_t old_v, new_v;
dword_addr.f162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.f162 = add_f16x2_t(cur_v.f162, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
#endif
}
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{
@@ -311,6 +361,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
(std::is_same<T, uint32_t>::value && (N == 1)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
@@ -406,6 +457,13 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
}
}
else if constexpr(std::is_same<T, fp16_t>::value)
{
static_for<0, N / 2, 1>{}([&](auto i) {
atomic_add(c_style_pointer_cast<fp16x2_t*>(p_dst) + i,
x.template get_as<fp16x2_t>()[i]);
});
}
}
template <typename T, index_t N>

View File

@@ -152,7 +152,7 @@
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx9__) // for GPU code
#elif defined(__gfx9__) || defined(__gfx12__) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
@@ -274,6 +274,12 @@
#define CK_TILE_WA_ISSUE_2028 0
#endif
#ifndef CK_TILE_WAVE32_ENABLED
#if defined(__gfx11__) || defined(__gfx12__)
#define CK_TILE_WAVE32_ENABLED
#endif
#endif
// Y pointed to R, we don't see a valuable use case.
// Will enforce encoding to check Y not pointed to R if set to zero
#ifndef CK_TILE_ENC_SUPPORT_Y_TO_R