mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Fix pre-commit hook error
This commit is contained in:
@@ -132,12 +132,12 @@ struct GemmPersistentAsyncInvoker
|
||||
ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args);
|
||||
auto c_ptr = ws_args.c_ptr;
|
||||
ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
|
||||
|
||||
// Add persistent async arguments to ws_args
|
||||
ws_args.chunk_signals = async_args.chunk_signals;
|
||||
ws_args.chunk_signals = async_args.chunk_signals;
|
||||
ws_args.tiles_per_chunk_m = async_args.tiles_per_chunk_m;
|
||||
|
||||
auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args);
|
||||
|
||||
auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args);
|
||||
|
||||
const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s)
|
||||
: GemmKernel::GridSize(args.M, args.N, args.k_batch);
|
||||
|
||||
@@ -38,7 +38,7 @@ struct GemmHostArgs
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_E_,
|
||||
uint32_t* chunk_signals_ = nullptr,
|
||||
uint32_t* chunk_signals_ = nullptr,
|
||||
index_t tiles_per_chunk_m_ = 0)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
@@ -76,7 +76,7 @@ struct GemmHostArgs
|
||||
};
|
||||
|
||||
index_t k_batch;
|
||||
|
||||
|
||||
// Persistent async arguments
|
||||
uint32_t* chunk_signals;
|
||||
index_t tiles_per_chunk_m;
|
||||
|
||||
@@ -18,10 +18,10 @@ namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Wait for a signal to become ready with acquire semantics
|
||||
*
|
||||
*
|
||||
* Producer-only wait: One lane polls chunk_signals[chunk_idx] with acquire semantics,
|
||||
* then a workgroup barrier releases everyone.
|
||||
*
|
||||
*
|
||||
* @param signal_addr Pointer to the signal location in device memory
|
||||
*/
|
||||
CK_TILE_DEVICE static inline void wait_signal(uint32_t* signal_addr)
|
||||
@@ -47,14 +47,14 @@ CK_TILE_DEVICE static inline void wait_signal(uint32_t* signal_addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Workgroup barrier to release all threads after signal is ready
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Fence for safe iteration boundaries in persistent loops
|
||||
*
|
||||
*
|
||||
* Ensures all memory operations are complete before reusing LDS or moving to next tile.
|
||||
* Uses s_waitcnt vmcnt=0, lgkmcnt=0 + s_barrier.
|
||||
*/
|
||||
@@ -62,10 +62,10 @@ CK_TILE_DEVICE static inline void iteration_boundary_fence()
|
||||
{
|
||||
// Wait for all vector memory operations (global memory loads/stores)
|
||||
__builtin_amdgcn_s_waitcnt_vmcnt(0);
|
||||
|
||||
|
||||
// Wait for all LDS operations
|
||||
__builtin_amdgcn_s_waitcnt_lgkmcnt(0);
|
||||
|
||||
|
||||
// Synchronize all threads in the workgroup
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
@@ -96,7 +96,7 @@ struct UniversalGemmHostArgs
|
||||
const std::array<index_t, NumBTensor>& stride_Bs_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_,
|
||||
uint32_t* chunk_signals_ = nullptr,
|
||||
uint32_t* chunk_signals_ = nullptr,
|
||||
index_t tiles_per_chunk_m_ = 0)
|
||||
: as_ptr(as_ptr_),
|
||||
bs_ptr(bs_ptr_),
|
||||
@@ -136,7 +136,7 @@ struct UniversalGemmHostArgs
|
||||
};
|
||||
|
||||
index_t k_batch;
|
||||
|
||||
|
||||
// Persistent async arguments
|
||||
uint32_t* chunk_signals;
|
||||
index_t tiles_per_chunk_m;
|
||||
@@ -173,7 +173,7 @@ struct UniversalGemmKernelArgs
|
||||
/// (in memory) of E tensor.
|
||||
index_t stride_E;
|
||||
index_t k_batch;
|
||||
|
||||
|
||||
/// @brief Pointer to chunk signals for async producer-consumer synchronization.
|
||||
/// chunk_signals[i] == 1 indicates that chunk i is ready.
|
||||
uint32_t* chunk_signals;
|
||||
@@ -1210,7 +1210,7 @@ struct UniversalGemmKernel
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
|
||||
// Producer-consumer synchronization: wait for chunk to be ready
|
||||
if(kargs.chunk_signals != nullptr && kargs.tiles_per_chunk_m > 0)
|
||||
{
|
||||
@@ -1283,11 +1283,11 @@ struct UniversalGemmKernel
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Safe iteration boundary: ensure all memory operations complete
|
||||
// before reusing LDS or moving to next tile
|
||||
iteration_boundary_fence();
|
||||
|
||||
|
||||
// Advance to the next work item
|
||||
block_id += grid_size;
|
||||
if(block_id >= num_work)
|
||||
|
||||
Reference in New Issue
Block a user