From 82cd975698edf45ba640c058f898998186ba20ad Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Fri, 12 Sep 2025 11:48:45 -0700 Subject: [PATCH] Fix the vector load & fix the gfx950 compv4 error (#2831) [ROCm/composable_kernel commit: 1894a0dbc304f6fd8b1d2fc9611658888baab22b] --- .../arch/amd_buffer_addressing_builtins.hpp | 39 ++++++++++++------- ...peline_ag_bg_cr_comp_v4_default_policy.hpp | 15 +++---- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 1 + 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 4013b51479..5c7ffefc6a 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1335,8 +1335,10 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe static_assert( (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || @@ -1449,14 +1451,19 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe src_wave_addr_offset, static_cast(coherence))); } - else if constexpr(N == 8) + else { - // use fp32 load to mimic fp16 load - fp32x4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + // N >= 8: build from fp32x4 chunks + thread_buffer tmp; + static_for<0, (N / 8), 1>{}([&](auto i) { + constexpr index_t chunk = i; + tmp.template get_as()(i) = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + (chunk * 4) * sizeof(float), + static_cast(coherence)); + }); return bit_cast(tmp); } } @@ -1486,13 +1493,19 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe src_wave_addr_offset, static_cast(coherence))); } - else if constexpr(N == 8) + else { - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + // N >= 8: build from fp32x4 chunks + thread_buffer tmp; + static_for<0, (N / 8), 1>{}([&](auto i) { + constexpr index_t chunk = i; + tmp.template get_as()(i) = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + (chunk * 4) * sizeof(float), + static_cast(coherence)); + }); return bit_cast(tmp); } } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index a80ed57be5..3164b41cc7 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -20,17 +20,18 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { - // using AccDataType = float; using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr bool single_load_tr_length = - (DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType)) == - (WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size()); + constexpr index_t vector_size = + DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType); + constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size(); constexpr auto wg_attr_num_access = - ((is_a_load_tr || is_b_load_tr) && !single_load_tr_length) - ? WGAttrNumAccessEnum::Double - : WGAttrNumAccessEnum::Single; + !(is_a_load_tr || is_b_load_tr) ? WGAttrNumAccessEnum::Single + : vector_size == thread_elements ? WGAttrNumAccessEnum::Single + : vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double + : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad + : WGAttrNumAccessEnum::Invalid; using WarpGemm = WarpGemmDispatcher struct WarpGemmDispatcher; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; + //WMMA cases template struct WarpGemmDispatcher { using Type =WarpGemmWmma_f32_16x16x16_f8_f8; }; template struct WarpGemmDispatcher { using Type =WarpGemmWmma_f32_16x16x16_bf8_bf8; };