mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Enable Async Copy for MI355 (#2425)
* add for async load builtin * add async load api * fix some compiling errors * fix a compiling error * fix some compiling errors * add a pipeline which copies from v4 * add a new pipeline for async load * fix some compiling errors * add async load tests * fix some issues in async load * fix * fix async inline assembly * fix async inline assembly * add ignore header file * comment some not gfx950 codes * comment some not gfx950 codes * fix a error * update async load apis * fix lds descriptor * fix a compiling error * fix some compiling errors * fix a descriptor issue * update lds descriptor * change async pipeline's tile distribution pattern from thread to warp * fix clang format * update async policy * fix a CRTP issue * fix a typo error * change lds layout * fix some sync issues * improve codes * delete the async test * fix a commented format issue * avoid compiling device functions when compile host * make gemm run * add the copy kernel support * finish the feature * Address comment * add the support for buffer_builtin * solved the merging problem * Comment Addressed --------- Co-authored-by: joye <joye@amd.com> Co-authored-by: joyeamd <John.Ye@amd.com>
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
// This attribute gives a hint to the compiler that a branch is likely to be taken.
|
||||
// Then, the compiler should remove if possible the associated s_cbranch_execz branch that would
|
||||
@@ -23,6 +24,8 @@
|
||||
#define LIKELY(x) (__builtin_expect(!!(x), 1))
|
||||
#endif
|
||||
|
||||
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||
@@ -1270,7 +1273,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
// Direct loads from global to LDS.
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
as3_uint32_ptr lds_ptr,
|
||||
index_t size,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
@@ -1749,7 +1752,7 @@ template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
@@ -1779,29 +1782,61 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
|
||||
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
#if defined(__gfx950__)
|
||||
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
|
||||
"wrong! only support in dword, dwordx3, dwordx4");
|
||||
ignore = src_wave_addr_offset;
|
||||
ignore = src_immediate_addr_offset;
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
0,
|
||||
0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
#else
|
||||
static_assert(bytes == 4, "wrong! not implemented vector size");
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
@@ -2775,9 +2810,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
: "memory");
|
||||
#else
|
||||
// LDS pointer must be attributed with the LDS address space.
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
as3_uint32_ptr lds_ptr =
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||
@@ -1138,7 +1140,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
// Direct loads from global to LDS.
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
as3_uint32_ptr lds_ptr,
|
||||
index_t size,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
@@ -1549,29 +1551,61 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
|
||||
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
#if defined(__gfx950__)
|
||||
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
|
||||
"wrong! only support in dword, dwordx3, dwordx4");
|
||||
ignore = src_wave_addr_offset;
|
||||
ignore = src_immediate_addr_offset;
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
0,
|
||||
0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
#else
|
||||
static_assert(bytes == 4, "wrong! not implemented vector size");
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr t*>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
@@ -2545,9 +2579,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
: "memory");
|
||||
#else
|
||||
// LDS pointer must be attributed with the LDS address space.
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
as3_uint32_ptr lds_ptr =
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
|
||||
|
||||
Reference in New Issue
Block a user