From 7eaa398458e0b4920798da44623cbd362c276bd8 Mon Sep 17 00:00:00 2001
From: Illia Silin <98187287+illsilin@users.noreply.github.com>
Date: Sun, 15 Jun 2025 15:22:34 -0700
Subject: [PATCH] Fix direct lds load for gfx950 and clang20 (#2346)

* fix direct lds load for gfx950 and clang20

* Update include/ck/utility/amd_buffer_addressing_builtins.hpp

* Fix format

---------

Co-authored-by: Aviral Goel
Co-authored-by: Andriy Roshchenko

[ROCm/composable_kernel commit: 2d8a804152ebaa36775fea393227cb956e6e550e]
---
 .../utility/amd_buffer_addressing_builtins.hpp | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/include/ck/utility/amd_buffer_addressing_builtins.hpp b/include/ck/utility/amd_buffer_addressing_builtins.hpp
index 1836e9461d..f642e06050 100644
--- a/include/ck/utility/amd_buffer_addressing_builtins.hpp
+++ b/include/ck/utility/amd_buffer_addressing_builtins.hpp
@@ -402,7 +402,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ
                                                       tmp.template AsType()[i]);
         });
     }
-#if defined(__gfx942__) || defined(__gfx950__)
+#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__)
     else if constexpr(is_same::value)
     {
         vector_type tmp{src_thread_data};
@@ -838,10 +838,18 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
                                               const bool is_valid,
                                               const index_t src_element_space_size)
 {
-    // Direct loads require that each thread reads and writes exactly a single DWORD.
-    constexpr auto dword_bytes = 4;
+    // Direct loads require that each thread reads and writes a multiple of DWORDs (4 bytes).
+    // For gfx950: supports 1, 3, or 4 DWORDs per thread
+    // For gfx942: supports exactly 1 DWORD per thread
     constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
+#if defined(__gfx950__)
+    constexpr auto dword_bytes = 4;
+    static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 ||
+                  bytes_per_thread == dword_bytes * 4);
+#elif defined(__gfx942__)
+    constexpr auto dword_bytes = 4;
     static_assert(bytes_per_thread == dword_bytes);
+#endif
 
     const int32x4_t src_resource =
         make_wave_buffer_resource(global_base_ptr, src_element_space_size);
@@ -872,7 +880,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
 #endif
 
     llvm_amdgcn_raw_buffer_load_lds(
-        src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
+        src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0);
 #endif
 }
 #endif