mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Refactor async loads to work on all GPUs (#2545)
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
committed by
GitHub
parent
821cd26c13
commit
cbfa62e4b6
@@ -1783,60 +1783,34 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
|
||||
// Used to catch the cases when src_immediate_addr_offset is NOT 0.
|
||||
// Remove this assert once other sizes are implemented.
|
||||
assert(src_immediate_addr_offset == 0 &&
|
||||
"wrong! not implemented src_immediate_addr_offset size, only 0 supported");
|
||||
ignore = src_immediate_addr_offset;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
|
||||
"wrong! only support in dword, dwordx3, dwordx4");
|
||||
ignore = src_wave_addr_offset;
|
||||
ignore = src_immediate_addr_offset;
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
0,
|
||||
0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
src_wave_addr_offset = 0;
|
||||
#else
|
||||
static_assert(bytes == 4, "wrong! not implemented vector size");
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
#endif
|
||||
|
||||
// Set up v_offset:
|
||||
index_t v_offset = src_thread_addr_offset;
|
||||
if constexpr(oob_conditional_check)
|
||||
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
/*src_immediate_addr_offset*/ 0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
|
||||
@@ -1553,60 +1553,34 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
|
||||
// Used to catch the cases when src_immediate_addr_offset is NOT 0.
|
||||
// Remove this assert once other sizes are implemented.
|
||||
assert(src_immediate_addr_offset == 0 &&
|
||||
"wrong! not implemented src_immediate_addr_offset size, only 0 supported");
|
||||
ignore = src_immediate_addr_offset;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
|
||||
"wrong! only support in dword, dwordx3, dwordx4");
|
||||
ignore = src_wave_addr_offset;
|
||||
ignore = src_immediate_addr_offset;
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
0,
|
||||
0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
src_wave_addr_offset = 0;
|
||||
#else
|
||||
static_assert(bytes == 4, "wrong! not implemented vector size");
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
#endif
|
||||
|
||||
// Set up v_offset:
|
||||
index_t v_offset = src_thread_addr_offset;
|
||||
if constexpr(oob_conditional_check)
|
||||
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
/*src_immediate_addr_offset*/ 0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
|
||||
Reference in New Issue
Block a user