diff --git a/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp b/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp index 4877853d2c..d9202a10d3 100644 --- a/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_persistent_async_invoker.hpp @@ -132,12 +132,12 @@ struct GemmPersistentAsyncInvoker ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); auto c_ptr = ws_args.c_ptr; ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - + // Add persistent async arguments to ws_args - ws_args.chunk_signals = async_args.chunk_signals; + ws_args.chunk_signals = async_args.chunk_signals; ws_args.tiles_per_chunk_m = async_args.tiles_per_chunk_m; - - auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); + + auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s) : GemmKernel::GridSize(args.M, args.N, args.k_batch); diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 76763a7c25..8dba7f9792 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -38,7 +38,7 @@ struct GemmHostArgs index_t stride_A_, index_t stride_B_, index_t stride_E_, - uint32_t* chunk_signals_ = nullptr, + uint32_t* chunk_signals_ = nullptr, index_t tiles_per_chunk_m_ = 0) : a_ptr(a_ptr_), b_ptr(b_ptr_), @@ -76,7 +76,7 @@ struct GemmHostArgs }; index_t k_batch; - + // Persistent async arguments uint32_t* chunk_signals; index_t tiles_per_chunk_m; diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index c71b6a2ff2..b50197f7d4 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -18,10 +18,10 @@ namespace ck_tile { /** * @brief Wait for a signal to become ready with acquire semantics - * + * * Producer-only wait: One lane polls chunk_signals[chunk_idx] with acquire semantics, * then a workgroup barrier releases everyone. - * + * * @param signal_addr Pointer to the signal location in device memory */ CK_TILE_DEVICE static inline void wait_signal(uint32_t* signal_addr) @@ -47,14 +47,14 @@ CK_TILE_DEVICE static inline void wait_signal(uint32_t* signal_addr) } } } - + // Workgroup barrier to release all threads after signal is ready __builtin_amdgcn_s_barrier(); } /** * @brief Fence for safe iteration boundaries in persistent loops - * + * * Ensures all memory operations are complete before reusing LDS or moving to next tile. * Uses s_waitcnt vmcnt=0, lgkmcnt=0 + s_barrier. */ @@ -62,10 +62,10 @@ CK_TILE_DEVICE static inline void iteration_boundary_fence() { // Wait for all vector memory operations (global memory loads/stores) __builtin_amdgcn_s_waitcnt_vmcnt(0); - + // Wait for all LDS operations __builtin_amdgcn_s_waitcnt_lgkmcnt(0); - + // Synchronize all threads in the workgroup __builtin_amdgcn_s_barrier(); } @@ -96,7 +96,7 @@ struct UniversalGemmHostArgs const std::array& stride_Bs_, const std::array& stride_Ds_, index_t stride_E_, - uint32_t* chunk_signals_ = nullptr, + uint32_t* chunk_signals_ = nullptr, index_t tiles_per_chunk_m_ = 0) : as_ptr(as_ptr_), bs_ptr(bs_ptr_), @@ -136,7 +136,7 @@ struct UniversalGemmHostArgs }; index_t k_batch; - + // Persistent async arguments uint32_t* chunk_signals; index_t tiles_per_chunk_m; @@ -173,7 +173,7 @@ struct UniversalGemmKernelArgs /// (in memory) of E tensor. index_t stride_E; index_t k_batch; - + /// @brief Pointer to chunk signals for async producer-consumer synchronization. /// chunk_signals[i] == 1 indicates that chunk i is ready. uint32_t* chunk_signals; @@ -1210,7 +1210,7 @@ struct UniversalGemmKernel const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx); const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); - + // Producer-consumer synchronization: wait for chunk to be ready if(kargs.chunk_signals != nullptr && kargs.tiles_per_chunk_m > 0) { @@ -1283,11 +1283,11 @@ struct UniversalGemmKernel i_n); } } - + // Safe iteration boundary: ensure all memory operations complete // before reusing LDS or moving to next tile iteration_boundary_fence(); - + // Advance to the next work item block_id += grid_size; if(block_id >= num_work)