[rocm-libraries] ROCm/rocm-libraries#4280 (commit b7de1e1)

[CK_TILE] Add blockscale GEMM support for EightWarps on
 gfx950 (#4280)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

gemm blockscale eightwarps support

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run `clang-format` on all changed files
- [x] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered
This commit is contained in:
kensclin
2026-02-09 03:55:52 +00:00
committed by assistant-librarian[bot]
parent 731afe535a
commit 5b3e527c88
19 changed files with 1881 additions and 225 deletions

View File

@@ -103,6 +103,12 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr,
}
return r;
}
CK_TILE_DEVICE __amdgpu_buffer_rsrc_t make_builtin_buffer_resource(const void* ptr,
uint32_t size = 0xffffffff)
{
return __builtin_amdgcn_make_buffer_rsrc(
const_cast<void*>(ptr), /*stride*/ 0, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD);
}
namespace impl {
// below type indicate the data type used for buffer load inline asm
@@ -1735,27 +1741,22 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
bool oob_conditional_check = true,
index_t IMM = 0>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
const __amdgpu_buffer_rsrc_t rsrc,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
index_t src_wave_addr_offset = 0,
number<IMM> /*src_immediate_addr_offset*/ = {},
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
constexpr index_t bytes = sizeof(T) * N;
// Used to catch the cases when src_immediate_addr_offset is NOT 0.
// Remove this assert once other sizes are implemented.
assert(src_immediate_addr_offset == 0 &&
"wrong! not implemented src_immediate_addr_offset size, only 0 supported");
ignore = src_immediate_addr_offset;
static_assert(IMM < (1 << 12), "wrong! immediate offset too large");
#if defined(__gfx950__)
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
"wrong! only support in dword, dwordx3, dwordx4");
src_wave_addr_offset = 0;
#else
static_assert(bytes == 4, "wrong! not implemented vector size");
#endif
@@ -1763,18 +1764,18 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
// Set up v_offset:
index_t v_offset = src_thread_addr_offset;
if constexpr(oob_conditional_check)
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
v_offset = flag ? v_offset : 0x7fffffff; // large offset to cause OOB access
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
// Use C-style cast to change address space without dropping llvm noalias attribute
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
(as3_uint32_ptr)(smem),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
__builtin_amdgcn_raw_ptr_buffer_load_lds(rsrc,
smem,
bytes,
v_offset,
src_wave_addr_offset,
/*imm*/ IMM,
static_cast<index_t>(coherence));
#pragma clang diagnostic pop
}
@@ -2585,22 +2586,24 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = false>
bool oob_conditional_check = false,
typename linear_offset_t>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
const int32x4_t src_wave_buffer_resource,
const __amdgpu_buffer_rsrc_t rsrc,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
index_t src_wave_addr_offset,
linear_offset_t,
bool is_valid_element,
bool_constant<oob_conditional_check> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
constexpr index_t src_linear_addr_offset = static_cast<index_t>(linear_offset_t{}) * sizeof(T);
amd_async_buffer_load<T, N, coherence>(smem,
src_wave_buffer_resource,
rsrc,
src_thread_addr_offset,
0,
src_linear_addr_offset,
src_wave_addr_offset,
number<src_linear_addr_offset>{},
is_valid_element,
bool_constant<oob_conditional_check>{});
}

View File

@@ -1016,6 +1016,11 @@ CK_TILE_DEVICE void s_waitcnt()
waitcnt_arg::from_lgkmcnt<lgkmcnt>());
#endif
}
template <index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
CK_TILE_DEVICE void s_waitcnt_lgkm()
{
s_waitcnt<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, lgkmcnt>();
}
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
index_t expcnt = waitcnt_arg::kMaxExpCnt,