From bf0dc8ce56637acad3fc190a061528087ded1b81 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Mon, 3 Nov 2025 09:37:35 -0800 Subject: [PATCH] fix the compv4 and async pipeline when tile handler is 1 (#3141) [ROCm/composable_kernel commit: 057b7d43b4f1edd4bc6e881403588af8c8e96fd4] --- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 17 ++++++++++++++++- .../pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp | 17 ++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 1d2a3e180b..91da3cd27b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -25,6 +25,10 @@ struct BaseGemmPipelineAgBgCrCompAsync CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) { + if(num_loop == 1) + { + return TailNumber::One; + } if(num_loop % PrefetchStages == 1) { return TailNumber::Three; @@ -65,6 +69,11 @@ struct BaseGemmPipelineAgBgCrCompAsync return run_func(bool_constant{}, integral_constant{}); } + else + { + return (run_func(bool_constant{}, + integral_constant{})); + } } // If execution reaches here, it's an invalid tail_number because it wasn't handled above. #if defined(__HIP_DEVICE_COMPILE__) @@ -485,7 +494,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}, integral_constant{}); } + else + { + return (run_func(bool_constant{}, + integral_constant{})); + } } // If execution reaches here, it's an invalid tail_number because it wasn't handled above. #if defined(__HIP_DEVICE_COMPILE__) @@ -621,7 +630,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 __builtin_amdgcn_sched_barrier(0); } } - else + else if(TailNum == TailNumber::Two) { // 2 { @@ -641,6 +650,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 __builtin_amdgcn_sched_barrier(0); } } + else if(TailNum == TailNumber::One) + { + block_sync_lds(); + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + __builtin_amdgcn_sched_barrier(0); + } return c_block_tile; } };