diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 0932f39ca7..29cc3fefe5 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -2762,11 +2762,6 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, const bool is_valid, const index_t src_element_space_size) { - // Direct loads require that each thread reads and writes exactly a single DWORD. - constexpr auto dword_bytes = 4; - constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; - static_assert(bytes_per_thread == dword_bytes); - const uint32_t* global_ptr = reinterpret_cast(reinterpret_cast(global_base_ptr)); const int32x4_t src_resource = @@ -2783,12 +2778,27 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, "s"(src_resource) : "memory"); #else + // Direct loads require that each thread reads and writes exactly a single DWORD. +#if defined(__gfx9__) + constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; +#endif + // Direct loads require that each thread reads and writes a multiple of DWORDs (4 bytes). + // For gfx950: supports 1, 3, or 4 DWORDs per thread + // For gfx942: supports exactly 1 DWORD per thread +#if defined(__gfx950__) + constexpr auto dword_bytes = 4; + static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 || + bytes_per_thread == dword_bytes * 4); +#elif defined(__gfx9__) + constexpr auto dword_bytes = 4; + static_assert(bytes_per_thread == dword_bytes); +#endif // LDS pointer must be attributed with the LDS address space. as3_uint32_ptr lds_ptr = reinterpret_cast(reinterpret_cast(lds_base_ptr + lds_offset)); llvm_amdgcn_raw_buffer_load_lds( - src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); + src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0); #endif } diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index ce4af430e2..8c3bc0bc36 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -2532,11 +2532,6 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, const bool is_valid, const index_t src_element_space_size) { - // Direct loads require that each thread reads and writes exactly a single DWORD. - constexpr auto dword_bytes = 4; - constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; - static_assert(bytes_per_thread == dword_bytes); - const uint32_t* global_ptr = reinterpret_cast(reinterpret_cast(global_base_ptr)); const int32x4_t src_resource = @@ -2553,12 +2548,27 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, "s"(src_resource) : "memory"); #else + // Direct loads require that each thread reads and writes exactly a single DWORD. +#if defined(__gfx9__) + constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; +#endif + // Direct loads require that each thread reads and writes a multiple of DWORDs (4 bytes). + // For gfx950: supports 1, 3, or 4 DWORDs per thread + // For gfx942: supports exactly 1 DWORD per thread +#if defined(__gfx950__) + constexpr auto dword_bytes = 4; + static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 || + bytes_per_thread == dword_bytes * 4); +#elif defined(__gfx9__) + constexpr auto dword_bytes = 4; + static_assert(bytes_per_thread == dword_bytes); +#endif // LDS pointer must be attributed with the LDS address space. as3_uint32_ptr lds_ptr = reinterpret_cast(reinterpret_cast(lds_base_ptr + lds_offset)); llvm_amdgcn_raw_buffer_load_lds( - src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); + src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0); #endif } diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py index e5e5ad4300..1584f706e9 100644 --- a/include/ck_tile/remod.py +++ b/include/ck_tile/remod.py @@ -76,7 +76,7 @@ class submodule_t: gen_header(Path(k) / (f'{km}.hpp'), kv) else: gen_header(Path(f'{k}.hpp'), v) - + submodule = submodule_t() # formatting