diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp index 0fe51d5003..54edf0c353 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -140,20 +140,15 @@ struct BlockwiseGemmXdlops_pipeline_v2= 1 ? 4 * warpSize / BlockSize : 1; - static constexpr index_t RegPerFetch = - (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock / BlockSize / 4; - - static constexpr index_t MaximumPrefetchStage = (256 / RegPerFetch) > 8 ? 8 - : (256 / RegPerFetch); + static constexpr index_t WgpPerCU = + (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( - 92 * 1024, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + 32768 / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); static constexpr index_t PrefetchStages = - FullMemBandPrefetchStages >= 2 ? FullMemBandPrefetchStages <= MaximumPrefetchStage - ? FullMemBandPrefetchStages - : MaximumPrefetchStage - : 2; + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 + : 2; static constexpr index_t PrefillStages = 1; static constexpr index_t GlobalBufferNum = PrefetchStages; @@ -635,10 +630,11 @@ struct BlockwiseGemmXdlops_pipeline_v2= 1 ? 4 * warpSize / BlockSize : 1; + static constexpr index_t WgpPerCU = + (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( - 92 * 1024, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + 32768 / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); static constexpr index_t PrefetchStages = FullMemBandPrefetchStages >= 2 ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8