diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 37f12e5dc7..4013b51479 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1335,10 +1335,8 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe static_assert( (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || - (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || @@ -1451,19 +1449,14 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe src_wave_addr_offset, static_cast(coherence))); } - else + else if constexpr(N == 8) { - // N >= 8: build from fp32x4 chunks - thread_buffer tmp; + // use fp32 load to mimic fp16 load + fp32x4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); - static_for<0, (N / 4), 1>{}([&](auto i) { - constexpr index_t chunk = i; - tmp.template get_as()(i) = llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + (chunk * 4) * sizeof(float), - static_cast(coherence)); - }); return bit_cast(tmp); } } @@ -1502,56 +1495,6 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe return bit_cast(tmp); } - else if constexpr(N == 16) - { - // use fp32 load to mimic bfp16 load - thread_buffer tmp; - - tmp.template get_as()(number<0>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - tmp.template get_as()(number<1>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); - - return bit_cast(tmp); - } - else if constexpr(N == 32) - { - // use fp32 load to mimic bfp16 load - thread_buffer tmp; - - tmp.template get_as()(number<0>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - tmp.template get_as()(number<1>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); - - tmp.template get_as()(number<2>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 8 * sizeof(float), - static_cast(coherence)); - - tmp.template get_as()(number<3>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 12 * sizeof(float), - static_cast(coherence)); - - return bit_cast(tmp); - } } else // other datatype {