diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 05775063b8..add6b1dbdc 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1783,60 +1783,34 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, bool_constant = {}) { constexpr index_t bytes = sizeof(T) * N; + + // Used to catch the cases when src_immediate_addr_offset is NOT 0. + // Remove this assert once other sizes are implemented. + assert(src_immediate_addr_offset == 0 && + "wrong! not implemented src_immediate_addr_offset size, only 0 supported"); + ignore = src_immediate_addr_offset; + #if defined(__gfx950__) static_assert(bytes == 4 || bytes == 12 || bytes == 16, "wrong! only support in dword, dwordx3, dwordx4"); - ignore = src_wave_addr_offset; - ignore = src_immediate_addr_offset; - if constexpr(oob_conditional_check) - { - index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - 0, - 0, - static_cast(coherence)); - } - else - { - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - src_thread_addr_offset, - 0, - 0, - static_cast(coherence)); - } + src_wave_addr_offset = 0; #else static_assert(bytes == 4, "wrong! not implemented vector size"); - if constexpr(oob_conditional_check) - { - index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } - else - { - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } #endif + + // Set up v_offset: + index_t v_offset = src_thread_addr_offset; + if constexpr(oob_conditional_check) + v_offset = flag ? v_offset : src_wave_buffer_resource[2]; + + llvm_amdgcn_raw_buffer_load_lds( + src_wave_buffer_resource, + reinterpret_cast(reinterpret_cast(smem)), + bytes, + v_offset, + src_wave_addr_offset, + /*src_immediate_addr_offset*/ 0, + static_cast(coherence)); } template = {}) { constexpr index_t bytes = sizeof(T) * N; + + // Used to catch the cases when src_immediate_addr_offset is NOT 0. + // Remove this assert once other sizes are implemented. + assert(src_immediate_addr_offset == 0 && + "wrong! not implemented src_immediate_addr_offset size, only 0 supported"); + ignore = src_immediate_addr_offset; + #if defined(__gfx950__) static_assert(bytes == 4 || bytes == 12 || bytes == 16, "wrong! only support in dword, dwordx3, dwordx4"); - ignore = src_wave_addr_offset; - ignore = src_immediate_addr_offset; - if constexpr(oob_conditional_check) - { - index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - 0, - 0, - static_cast(coherence)); - } - else - { - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - src_thread_addr_offset, - 0, - 0, - static_cast(coherence)); - } + src_wave_addr_offset = 0; #else static_assert(bytes == 4, "wrong! not implemented vector size"); - if constexpr(oob_conditional_check) - { - index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } - else - { - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } #endif + + // Set up v_offset: + index_t v_offset = src_thread_addr_offset; + if constexpr(oob_conditional_check) + v_offset = flag ? v_offset : src_wave_buffer_resource[2]; + + llvm_amdgcn_raw_buffer_load_lds( + src_wave_buffer_resource, + reinterpret_cast(reinterpret_cast(smem)), + bytes, + v_offset, + src_wave_addr_offset, + /*src_immediate_addr_offset*/ 0, + static_cast(coherence)); } template