From a2b87dd39d3a33d7739df1bfaa8fda8b37933714 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 22 Aug 2025 10:13:47 +0800 Subject: [PATCH] [CK_TILE][FMHA] Enable dwordx4 loading in async_load_tile_raw() (#2549) * Support async load dwordx4 * Enlarge load size on gfx950 [ROCm/composable_kernel commit: 4a7ecce096fa9008934b38336bc2ea4f2066a16d] --- .../core/arch/amd_buffer_addressing.hpp | 73 ++++++++++++------- .../arch/amd_buffer_addressing_builtins.hpp | 73 ++++++++++++------- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 8 +- 3 files changed, 103 insertions(+), 51 deletions(-) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 07be65a150..037e86909d 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1276,26 +1276,46 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, index_t offset, index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); -template -CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, - int32x4_t rsrc, - index_t voffset, - index_t /*soffset*/, - index_t ioffset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) +template +CK_TILE_DEVICE void async_buffer_load_dwordxn_v(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t /*soffset*/, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0, + bool_constant = {}) { - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) +#define CK_TILE_ASYNC_LOAD_WITH_INSTR(instr) \ + if constexpr(pre_nop) \ + asm volatile("s_nop 4\n" instr " %1, %2, 0 offen offset:%3 lds" \ + : "=r"(smem) /*dummy dependency for smem*/ \ + : "v"(voffset), "s"(rsrc), "n"(ioffset) \ + : "memory"); \ + else \ + asm volatile(instr " %1, %2, 0 offen offset:%3 lds" \ + : "=r"(smem) /*dummy dependency for smem*/ \ + : "v"(voffset), "s"(rsrc), "n"(ioffset) \ : "memory"); + + if constexpr(num_dwords == 1) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dword"); + } +#if defined(__gfx950__) + else if constexpr(num_dwords == 3) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx3"); + } + else if constexpr(num_dwords == 4) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx4"); + } +#endif else - asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) - : "memory"); + { + static_assert(false, "wrong! not implemented data width"); + } +#undef CK_TILE_ASYNC_LOAD_WITH_INSTR } CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) @@ -1766,15 +1786,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem, index_t src_immediate_addr_offset = 0, bool_constant = {}) { - static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); + constexpr index_t num_bytes = sizeof(T) * N; + constexpr index_t num_words = num_bytes / 4; + static_assert(num_bytes % 4 == 0 && (num_words == 1 || num_words == 3 || num_words == 4), + "wrong! only support in dword, dwordx3, dwordx4"); - async_buffer_load_dword_v(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - 0, - bool_constant{}); + async_buffer_load_dwordxn_v(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + 0, + bool_constant{}); } template -CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, - int32x4_t rsrc, - index_t voffset, - index_t /*soffset*/, - index_t ioffset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) +template +CK_TILE_DEVICE void async_buffer_load_dwordxn_v(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t /*soffset*/, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0, + bool_constant = {}) { - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) +#define CK_TILE_ASYNC_LOAD_WITH_INSTR(instr) \ + if constexpr(pre_nop) \ + asm volatile("s_nop 4\n" instr " %1, %2, 0 offen offset:%3 lds" \ + : "=r"(smem) /*dummy dependency for smem*/ \ + : "v"(voffset), "s"(rsrc), "n"(ioffset) \ + : "memory"); \ + else \ + asm volatile(instr " %1, %2, 0 offen offset:%3 lds" \ + : "=r"(smem) /*dummy dependency for smem*/ \ + : "v"(voffset), "s"(rsrc), "n"(ioffset) \ : "memory"); + + if constexpr(num_dwords == 1) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dword"); + } +#if defined(__gfx950__) + else if constexpr(num_dwords == 3) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx3"); + } + else if constexpr(num_dwords == 4) + { + CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx4"); + } +#endif else - asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) - : "memory"); + { + static_assert(false, "wrong! not implemented data width"); + } +#undef CK_TILE_ASYNC_LOAD_WITH_INSTR } CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) @@ -1536,15 +1556,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, index_t src_immediate_addr_offset = 0, bool_constant = {}) { - static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); + constexpr index_t num_bytes = sizeof(T) * N; + constexpr index_t num_words = num_bytes / 4; + static_assert(num_bytes % 4 == 0 && (num_words == 1 || num_words == 3 || num_words == 4), + "wrong! only support in dword, dwordx3, dwordx4"); - async_buffer_load_dword_v(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - 0, - bool_constant{}); + async_buffer_load_dwordxn_v(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + 0, + bool_constant{}); } template ; if constexpr(AsyncCopy) { - return 4 / sizeof(KDataType); +#if defined(__gfx950__) + constexpr index_t MaxLoadSizeInBytes = 4 * 4; // dwordx4 +#else + constexpr index_t MaxLoadSizeInBytes = 4; // dword +#endif + + return MaxLoadSizeInBytes / sizeof(KDataType); } else {