mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
[CK_TILE] CK_TILE GEMM WMMA Support for GFX11/GFX12 (#2466)
* WMMA GEMM F16 Implementation
Signed-off-by: root <tianyuwu@amd.com>
* Self-review
Signed-off-by: root <tianyuwu@amd.com>
* ASIC check minor tweak
Signed-off-by: root <tianyuwu@amd.com>
* add missing include file
* Set GPU_TARGETS to gfx11/12 generic
Signed-off-by: root <tianyuwu@amd.com>
* INT8 GFX12
Signed-off-by: root <tianyuwu@amd.com>
* add int8x16 branch
* Fix CI script
Signed-off-by: root <tianyuwu@amd.com>
* Fix typo
Signed-off-by: root <tianyuwu@amd.com>
* Add CK_Tile WMMA example
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
* Fix CI
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
* fix clang format
* Set M/N_Warp Back to Constant
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
* Use GemmConfigComputeV3 by default
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Remove CK_Tile wmma gemm examples from the CI list
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Add atomic add fallback method for gfx11
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Fix typo
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Omit copyright year
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Support non-square cases
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Fix CI
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Add get_device_ip()
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Revert "Add atomic add fallback method for gfx11"
This reverts commit 4f664969c01b37976c8518c19833d9f1574cd746.
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
* Revert "Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12"
This reverts commit 949129a3858a825b2a2c4d3ec01663df18a165a5.
* Revise method name and typos
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
* clang-format
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Try fix CI
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Revert "Try fix CI"
This reverts commit 084c683227e64ab6a8137db00c8165fb05bdc902.
* clang-format
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Fix typo caused by merge
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
* Fix typo caused by merging
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
---------
Signed-off-by: root <tianyuwu@amd.com>
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
[ROCm/composable_kernel commit: 68134b60e4]
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111
|
||||
#define CK_TILE_VMCNT(cnt) \
|
||||
@@ -59,7 +60,7 @@ enum struct memory_operation_enum : std::uint16_t
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
|
||||
{
|
||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
#if defined(__GFX9__) || (!defined(__HIP_DEVICE_COMPILE__) && !defined(CK_TILE_WAVE32_ENABLED))
|
||||
return 64;
|
||||
#else
|
||||
return 32;
|
||||
@@ -230,4 +231,20 @@ CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_
|
||||
}
|
||||
}
|
||||
|
||||
// Architecture tags
|
||||
struct gfx11_t
|
||||
{
|
||||
};
|
||||
struct gfx12_t
|
||||
{
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_device_arch()
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
return gfx11_t{};
|
||||
#else // if defined(__gfx12__)
|
||||
return gfx12_t{};
|
||||
#endif
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,10 @@
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
|
||||
// Evaluates to 1 when the compiler provides BOTH packed global atomic-add
// builtins (the v2f16 and the v2bf16 variant), and to 0 otherwise.
//
// The expansion is wrapped in parentheses so the macro behaves as a single
// operand: without them `#if !HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN` would negate
// only the first __has_builtin term instead of the whole conjunction.
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN                            \
    (__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) &&    \
     __has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16))
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T, typename ComputeType>
|
||||
@@ -32,6 +36,14 @@ CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
|
||||
return rtn;
|
||||
}
|
||||
|
||||
// Element-wise sum of two packed fp16 pairs.
// Each of the two lanes is added independently through add<fp16_t, float>
// (presumably float is the intermediate compute type -- the add<> template
// itself is defined outside this view).
CK_TILE_HOST_DEVICE fp16x2_t add_f16x2_t(const fp16x2_t& a, const fp16x2_t& b)
{
    const auto lane0 = add<fp16_t, float>(a[0], b[0]);
    const auto lane1 = add<fp16_t, float>(a[1], b[1]);

    fp16x2_t sum;
    sum[0] = lane0;
    sum[1] = lane1;
    return sum;
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
|
||||
{
|
||||
fp8x4_t rtn;
|
||||
@@ -304,6 +316,44 @@ CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic add for fp16x2_t
// Adds the packed pair `x` to the packed pair stored at `p_dst` as one atomic
// update. Two implementations are selected at preprocessing time:
//  * HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN true: forward to the compiler builtin
//    __builtin_amdgcn_global_atomic_fadd_v2f16.
//  * otherwise: a compare-and-swap retry loop on the 32-bit word at `p_dst`.
// Nothing is returned to the caller in either path.
|
||||
//
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, fp16x2_t const& x)
|
||||
{
|
||||
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
|
||||
// Builtin path: the result of the builtin call is discarded.
__builtin_amdgcn_global_atomic_fadd_v2f16(c_style_pointer_cast<fp16x2_t*>(p_dst), x);
|
||||
#else
|
||||
// Fallback path. The two unions below reinterpret the destination pointer and
// the packed value as uint32_t so atomicCAS can operate on them.
// NOTE(review): this assumes fp16x2_t occupies exactly 32 bits -- confirm.
union U32F162_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
fp16x2_t* f162_a;
|
||||
};
|
||||
|
||||
union U32F162
|
||||
{
|
||||
uint32_t u32;
|
||||
fp16x2_t f162;
|
||||
};
|
||||
|
||||
U32F162_ADDR dword_addr;
|
||||
U32F162 cur_v;
|
||||
U32F162 new_;
|
||||
uint32_t old_v, new_v;
|
||||
dword_addr.f162_a = p_dst;
|
||||
// Initial (non-atomic) read of the destination word; it seeds the first
// compare-and-swap attempt.
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
// Value we expect to still be in memory when the swap is attempted.
old_v = cur_v.u32;
|
||||
// Candidate result: the expected contents plus x, added lane by lane.
new_.f162 = add_f16x2_t(cur_v.f162, x);
|
||||
new_v = new_.u32;
|
||||
// atomicCAS hands back the word it actually found at the address; that value
// is kept in cur_v so a failed attempt retries from the fresh contents.
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
// Repeat until the word found equals the word expected, i.e. the swap landed.
} while(cur_v.u32 != old_v);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
@@ -311,6 +361,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
|
||||
@@ -406,6 +457,13 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, fp16_t>::value)
|
||||
{
|
||||
static_for<0, N / 2, 1>{}([&](auto i) {
|
||||
atomic_add(c_style_pointer_cast<fp16x2_t*>(p_dst) + i,
|
||||
x.template get_as<fp16x2_t>()[i]);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
|
||||
@@ -152,7 +152,7 @@
|
||||
// buffer atomic add: floating point
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#elif defined(__gfx9__) // for GPU code
|
||||
#elif defined(__gfx9__) || defined(__gfx12__) // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#else // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
|
||||
@@ -274,6 +274,12 @@
|
||||
#define CK_TILE_WA_ISSUE_2028 0
|
||||
#endif
|
||||
|
||||
// Define CK_TILE_WAVE32_ENABLED automatically when compiling for gfx11 or
// gfx12, unless it has already been defined before this point.
#if !defined(CK_TILE_WAVE32_ENABLED) && (defined(__gfx11__) || defined(__gfx12__))
#define CK_TILE_WAVE32_ENABLED
#endif
|
||||
|
||||
// We do not see a valuable use case for Y pointing to R.
|
||||
// If set to zero, the encoding enforces a check that Y does not point to R.
|
||||
#ifndef CK_TILE_ENC_SUPPORT_Y_TO_R
|
||||
|
||||
Reference in New Issue
Block a user