diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp old mode 100755 new mode 100644 index 450ceaa5eb..70b63d2ca4 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp @@ -36,7 +36,8 @@ * @param chunk_idx Index of chunk to signal * @param stream HIP stream for async operations */ -[[maybe_unused]] static void signal_chunk_ready(uint32_t* signals, int chunk_idx, hipStream_t stream) +[[maybe_unused]] static void +signal_chunk_ready(uint32_t* signals, int chunk_idx, hipStream_t stream) { uint32_t ready = 1; ck_tile::hip_check_error(hipMemcpyAsync( @@ -67,7 +68,7 @@ int main(int argc, char* argv[]) const std::string a_layout = arg_parser.get_str("a_layout"); const std::string b_layout = arg_parser.get_str("b_layout"); const std::string data_type = arg_parser.get_str("prec"); - + auto res = invoke_grouped_gemm_persistent_async( a_layout, b_layout, data_type, arg_parser, , tiles_per_chunk_m, tile_idx_pivot_m); @@ -76,6 +77,5 @@ int main(int argc, char* argv[]) */ - return 0; } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp old mode 100755 new mode 100644 index 814f45900e..599ec70746 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp @@ -3,106 +3,112 @@ #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/epilogue.hpp" +template +void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) +{ + constexpr bool TransposeC = false; + constexpr bool DoubleSmemBuffer = false; -template - void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, - const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = memory_operation_.value; + + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + }; + + if(splitk) { - constexpr bool TransposeC = false; - constexpr bool DoubleSmemBuffer = false; - - constexpr int kBlockPerCu = 1; - constexpr ck_tile::index_t TileParitionerGroupNum = 8; - constexpr ck_tile::index_t TileParitionerM01 = 4; - - using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - using TilePartitioner = ck_tile:: - GemmSpatiallyLocalTilePartitioner; - - using GemmUniversalTraits = - ck_tile::PersistentTileGemmUniversalTraits; - - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - constexpr auto memory_operation = memory_operation_.value; - - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } - - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - }; - - if(splitk) - { - Run(ck_tile::integral_constant{}); - } - else - { - - Run(ck_tile::integral_constant{}); - } + Run(ck_tile::integral_constant{}); } + else + { + + Run(ck_tile::integral_constant{}); + } +} diff --git a/example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp b/example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp old mode 100755 new mode 100644 index 9c6c381196..c8533d948b --- a/example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp +++ b/example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp @@ -43,19 +43,20 @@ CK_TILE_DEVICE static void wait_chunk_signal(const uint32_t* chunk_signals, inde if(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { volatile const uint32_t* signal_ptr = chunk_signals + chunk_idx; - + // Poll until chunk is ready (signal == 1) // Use acquire semantics for proper memory ordering uint32_t signal_value; - do { + do + { signal_value = __builtin_nontemporal_load(signal_ptr); __builtin_amdgcn_s_sleep(1); // Brief sleep to reduce contention } while(signal_value == 0); - + // Memory fence with acquire semantics __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "agent"); } - + // Barrier to release all threads in the workgroup __builtin_amdgcn_s_barrier(); }