From a5e36a24940ecccc402d3c8df698f5a06c23d533 Mon Sep 17 00:00:00 2001 From: valarLip <103567126+valarLip@users.noreply.github.com> Date: Fri, 6 Jun 2025 17:21:19 +0800 Subject: [PATCH] extend buffer load to support load 32 bf16/fp16 at same time (#2291) [ROCm/composable_kernel commit: 8482977a3752f0d8205d7b5530f28db3c6a3dc5f] --- .../core/arch/amd_buffer_addressing.hpp | 66 ++++++++++++++++++- 1 file changed, 64 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 68648e1c02..7111eed596 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1437,8 +1437,10 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe static_assert( (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || @@ -1579,6 +1581,36 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe return bit_cast(tmp); } + else if constexpr(N == 32) + { + thread_buffer tmp; + + tmp.template get_as()(number<0>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + tmp.template get_as()(number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + static_cast(coherence)); + + tmp.template get_as()(number<2>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(float), + static_cast(coherence)); + + tmp.template get_as()(number<3>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(float), + static_cast(coherence)); + + return bit_cast(tmp); + } } else if constexpr(std::is_same::value) // bf16 { @@ -1633,6 +1665,36 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe return bit_cast(tmp); } + else if constexpr(N == 32) + { + thread_buffer tmp; + + tmp.template get_as()(number<0>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + tmp.template get_as()(number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + static_cast(coherence)); + + tmp.template get_as()(number<2>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(float), + static_cast(coherence)); + + tmp.template get_as()(number<3>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(float), + static_cast(coherence)); + + return bit_cast(tmp); + } } else // other datatype {