From 8afac88f89ade4e732ae58a66c3c360bd20e53ce Mon Sep 17 00:00:00 2001 From: joye Date: Fri, 23 May 2025 17:13:10 +0800 Subject: [PATCH] fix f4 pipeline issues --- ...gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp | 53 +++++++++++-------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp index bb3b83bf15..357b91373b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp @@ -417,7 +417,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle{}([&](auto m0) { + static_for<0, math::min(2 * MXdlPack, MRepeat), 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops); @@ -465,9 +465,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle{}([&](auto m0) { @@ -512,10 +512,22 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle{}([&](auto m0) { + if constexpr(m0.value == (MRepeat/ MXdlPack - LocalPrefetchStages)) + { + block_sync_lds(); + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf)); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + } + + constexpr auto lds_buf = + m0.value >= (MRepeat/ MXdlPack - LocalPrefetchStages) + ? scale_mem_buf + : scale_comp_buf; + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { constexpr index_t a_scale_offset = @@ -602,15 +614,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle= (MRepeat/ MXdlPack - LocalPrefetchStages) - ? scale_mem_buf - : scale_comp_buf; static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, MXdlPack, 1>{}([&](auto imxdl) { @@ -642,6 +645,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle{}, I1, Number{}, Number{}, Number{}), - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - I1)); +// Length: A[ARegBuf, MWave, MXdlPack, KRepeat, KPack] + // Order: 1 0 3 2 4 + static constexpr auto ARegBuf = 2; + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + I1, + Number{}, + Number{}, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + I1)); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4