diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
index f5166cfdcb..d5ba324326 100644
--- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
+++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
@@ -1226,23 +1226,37 @@ struct UniversalGemmKernel
         s_waitcnt_barrier();
         const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles);
         const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
-        const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
-        const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
+        // Apply pivot to M tile index first, then use the same pivoted index
+        // for both data-tile selection and chunk-signal wait.
+        auto iM_eff = amd_wave_read_first_lane(iM);
+
+        if(kargs.async_input_scheduler.chunk_signals != nullptr)
+        {
+            const auto tile_idx_pivot =
+                amd_wave_read_first_lane(kargs.async_input_scheduler.tile_idx_pivot_m);
+            const auto tiles_m = amd_wave_read_first_lane(
+                integer_divide_ceil(kargs.M, TilePartitioner::MPerBlock));
+            if(tiles_m > 0)
+            {
+                iM_eff = amd_wave_read_first_lane((iM_eff + tile_idx_pivot) % tiles_m);
+            }
+        }
+
+        const index_t i_m = amd_wave_read_first_lane(iM_eff * TilePartitioner::MPerBlock);
+        const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);

         // Synchronize with producer to ensure input data is ready before processing tile
         if(kargs.async_input_scheduler.chunk_signals != nullptr)
         {
             const auto tiles_per_chunk =
                 amd_wave_read_first_lane(kargs.async_input_scheduler.tiles_per_chunk_m);
-            const auto tile_idx_pivot =
-                amd_wave_read_first_lane(kargs.async_input_scheduler.tile_idx_pivot_m);
             const auto num_chunks = amd_wave_read_first_lane(kargs.async_input_scheduler.num_chunks);
             if(tiles_per_chunk > 0 && num_chunks > 0)
             {
                 // Pivot allows rotating chunk assignments for load balancing
-                const auto chunk_idx = amd_wave_read_first_lane(
-                    ((iM + tile_idx_pivot) / tiles_per_chunk) % num_chunks);
+                const auto chunk_idx =
+                    amd_wave_read_first_lane((iM_eff / tiles_per_chunk) % num_chunks);
                 workgroup_barrier chunk_barrier(kargs.async_input_scheduler.chunk_signals);
                 chunk_barrier.wait_eq_wave(/*value=*/1, /*offset=*/chunk_idx);
             }