mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#4280 (commit b7de1e1)
[CK_TILE] Add blockscale GEMM support for EightWarps on gfx950 (#4280) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes gemm blockscale eightwarps support ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [x] I have run `clang-format` on all changed files - [x] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
731afe535a
commit
5b3e527c88
@@ -103,6 +103,12 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr,
|
||||
}
|
||||
return r;
|
||||
}
|
||||
CK_TILE_DEVICE __amdgpu_buffer_rsrc_t make_builtin_buffer_resource(const void* ptr,
|
||||
uint32_t size = 0xffffffff)
|
||||
{
|
||||
return __builtin_amdgcn_make_buffer_rsrc(
|
||||
const_cast<void*>(ptr), /*stride*/ 0, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD);
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
// below type indicate the data type used for buffer load inline asm
|
||||
@@ -1735,27 +1741,22 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true>
|
||||
bool oob_conditional_check = true,
|
||||
index_t IMM = 0>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
const __amdgpu_buffer_rsrc_t rsrc,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
index_t src_immediate_addr_offset = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
index_t src_wave_addr_offset = 0,
|
||||
number<IMM> /*src_immediate_addr_offset*/ = {},
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
|
||||
// Used to catch the cases when src_immediate_addr_offset is NOT 0.
|
||||
// Remove this assert once other sizes are implemented.
|
||||
assert(src_immediate_addr_offset == 0 &&
|
||||
"wrong! not implemented src_immediate_addr_offset size, only 0 supported");
|
||||
ignore = src_immediate_addr_offset;
|
||||
static_assert(IMM < (1 << 12), "wrong! immediate offset too large");
|
||||
|
||||
#if defined(__gfx950__)
|
||||
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
|
||||
"wrong! only support in dword, dwordx3, dwordx4");
|
||||
src_wave_addr_offset = 0;
|
||||
#else
|
||||
static_assert(bytes == 4, "wrong! not implemented vector size");
|
||||
#endif
|
||||
@@ -1763,18 +1764,18 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
// Set up v_offset:
|
||||
index_t v_offset = src_thread_addr_offset;
|
||||
if constexpr(oob_conditional_check)
|
||||
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
v_offset = flag ? v_offset : 0x7fffffff; // large offset to cause OOB access
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
// Use C-style cast to change address space without dropping llvm noalias attribute
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
(as3_uint32_ptr)(smem),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
/*src_immediate_addr_offset*/ 0,
|
||||
static_cast<index_t>(coherence));
|
||||
__builtin_amdgcn_raw_ptr_buffer_load_lds(rsrc,
|
||||
smem,
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
/*imm*/ IMM,
|
||||
static_cast<index_t>(coherence));
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
@@ -2585,22 +2586,24 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = false>
|
||||
bool oob_conditional_check = false,
|
||||
typename linear_offset_t>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
|
||||
const int32x4_t src_wave_buffer_resource,
|
||||
const __amdgpu_buffer_rsrc_t rsrc,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
linear_offset_t,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
constexpr index_t src_linear_addr_offset = static_cast<index_t>(linear_offset_t{}) * sizeof(T);
|
||||
|
||||
amd_async_buffer_load<T, N, coherence>(smem,
|
||||
src_wave_buffer_resource,
|
||||
rsrc,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
number<src_linear_addr_offset>{},
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
@@ -1016,6 +1016,11 @@ CK_TILE_DEVICE void s_waitcnt()
|
||||
waitcnt_arg::from_lgkmcnt<lgkmcnt>());
|
||||
#endif
|
||||
}
|
||||
template <index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
|
||||
CK_TILE_DEVICE void s_waitcnt_lgkm()
|
||||
{
|
||||
s_waitcnt<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, lgkmcnt>();
|
||||
}
|
||||
|
||||
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
|
||||
index_t expcnt = waitcnt_arg::kMaxExpCnt,
|
||||
|
||||
Reference in New Issue
Block a user