From 8b842250daf8b1fab2bb8a7f250974cf6e06080e Mon Sep 17 00:00:00 2001
From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Date: Tue, 20 Jan 2026 10:37:09 -0800
Subject: [PATCH] Add persistent async input scheduler for GEMM kernels (#3520)

Add signal-based synchronization for persistent GEMM kernels where input
data becomes available incrementally. Uses modulo wraparound (like
PyTorch's AsyncMM) for chunk index calculation:

    chunk_idx = ((tile_idx + tile_idx_pivot) / tiles_per_chunk) % num_chunks

Key components:
- PersistentAsyncInputScheduler struct with tiles_per_chunk_m, chunk_signals,
  tile_idx_pivot_m, and num_chunks fields
- wait_eq_wave method using __builtin_amdgcn_s_sleep for power efficiency
- IsSupportedArgument validation for scheduler parameters
- Example demonstrating async input scheduling with simulated producer
- GTest unit tests covering all layout combinations

[ROCm/composable_kernel commit: 91b4102a59c6013d3faeb54f250cf577b2f129ce]
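To make the chunk mapping above concrete, here is a small host-side sketch of the
producer/consumer handshake it describes. It is illustrative only, not code from this
patch: the sizes are made up, the std::atomic vector stands in for the device-side
chunk_signals array, and the yield loop stands in for the wave-level wait_eq_wave
(which sleeps via __builtin_amdgcn_s_sleep on the GPU). Only the index arithmetic is
taken directly from the formula above.

```cpp
#include <atomic>
#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

int main()
{
    // Illustrative sizes only; real values come from the GEMM tiling and the producer.
    const std::uint32_t num_chunks      = 4; // chunks the producer publishes incrementally
    const std::uint32_t tiles_per_chunk = 8; // output tiles covered by one chunk (tiles_per_chunk_m)
    const std::uint32_t tile_idx_pivot  = 0; // starting offset (tile_idx_pivot_m)
    const std::uint32_t num_tiles       = num_chunks * tiles_per_chunk;

    // Host stand-in for the device-side chunk_signals array: one flag per chunk.
    std::vector<std::atomic<std::uint32_t>> chunk_signals(num_chunks);
    for(auto& s : chunk_signals)
        s.store(0);

    // Producer: marks chunks ready one by one, as an asynchronous input stream would.
    std::thread producer([&] {
        for(std::uint32_t c = 0; c < num_chunks; ++c)
            chunk_signals[c].store(1, std::memory_order_release);
    });

    // Consumer: persistent loop over tiles; each tile waits for the chunk that owns its input.
    for(std::uint32_t tile_idx = 0; tile_idx < num_tiles; ++tile_idx)
    {
        const std::uint32_t chunk_idx =
            ((tile_idx + tile_idx_pivot) / tiles_per_chunk) % num_chunks;
        while(chunk_signals[chunk_idx].load(std::memory_order_acquire) != 1)
            std::this_thread::yield(); // the kernel instead sleeps inside wait_eq_wave
        // ... compute the GEMM tile whose input lives in chunk `chunk_idx` ...
    }

    producer.join();
    std::cout << "consumed " << num_tiles << " tiles across " << num_chunks << " chunks\n";
}
```

The trailing % num_chunks is the wraparound: a fixed set of chunk signals can cover an
arbitrarily long persistent tile range, matching the AsyncMM-style scheme referenced above.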
---
 CHANGELOG.md                                  |   1 +
 example/ck_tile/03_gemm/gemm_utils.hpp        |   3 +-
 example/ck_tile/03_gemm/universal_gemm.cpp    | 229 +++++++++++--
 .../03_gemm/universal_gemm_invoker.hpp        | 170 ++++++++++
 include/ck_tile/core.hpp                      |   1 +
 .../ck_tile/core/arch/workgroup_barrier.hpp   |  30 ++
 .../persistent_async_input_scheduler.hpp      |  49 +++
 .../ops/gemm/kernel/universal_gemm_kernel.hpp |  98 ++++--
 test/ck_tile/CMakeLists.txt                   |   1 +
 .../CMakeLists.txt                            |  19 ++
 .../test_gemm_persistent_async_input.cpp      | 304 ++++++++++++++++++
 11 files changed, 844 insertions(+), 61 deletions(-)
 create mode 100644 include/ck_tile/core/utility/persistent_async_input_scheduler.hpp
 create mode 100644 test/ck_tile/gemm_persistent_async_input/CMakeLists.txt
 create mode 100644 test/ck_tile/gemm_persistent_async_input/test_gemm_persistent_async_input.cpp

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 066dc9aa3b..c3a257e464 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
 * Added support for gfx1153 target.
 * Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
 * Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines.
+* Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming.
 
 ### Changed
 
diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp
index 8eff0e7469..c1df27ecc8 100644
--- a/example/ck_tile/03_gemm/gemm_utils.hpp
+++ b/example/ck_tile/03_gemm/gemm_utils.hpp
@@ -456,7 +456,8 @@ inline auto create_args()
         .insert("json", "0", "0: No Json, 1: Dump Results in Json format")
         .insert("jsonfile", "gemm.json", "json file name to dump results")
         .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
-        .insert("rotating_count", "1000", "rotating count, defaults to 1000");
+        .insert("rotating_count", "1000", "rotating count, defaults to 1000")
+        .insert("test_async", "0", "0: normal gemm, 1: test async input scheduler");
 
     return arg_parser;
 }
diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp
index c1c8a2fc89..ace9152747 100644
--- a/example/ck_tile/03_gemm/universal_gemm.cpp
+++ b/example/ck_tile/03_gemm/universal_gemm.cpp
@@ -12,6 +12,169 @@
 #include "run_gemm_example_common.hpp"
 #include "universal_gemm_invoker.hpp"
 
+// Universal GEMM-specific wrapper that handles test_async flag
+template <typename GemmConfig,
+          typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
+int run_gemm_example_with_layouts_universal(ck_tile::ArgParser& arg_parser,
+                                            const ALayout a_layout = ALayout{},
+                                            const BLayout b_layout = BLayout{},
+                                            const CLayout c_layout = CLayout{})
+{
+    using Invoker     = UniversalInvoker;
+    using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
+
+    // Check for async input scheduler test mode
+    bool test_async = arg_parser.get_int("test_async");
+    if(test_async)
+    {
+        // Extract parameters for async test (same as shared implementation)
+        const ck_tile::index_t M      = arg_parser.get_int("m");
+        const ck_tile::index_t N      = arg_parser.get_int("n");
+        const ck_tile::index_t K      = arg_parser.get_int("k");
+        const ck_tile::index_t kbatch = arg_parser.get_int("split_k");
+
+        using Row = ck_tile::tensor_layout::gemm::RowMajor;
+        constexpr bool is_a_row_major = std::is_same_v<ALayout, Row>;
+        constexpr bool is_b_row_major = std::is_same_v<BLayout, Row>;
+        constexpr bool is_c_row_major = std::is_same_v<CLayout, Row>;
+
+        const ck_tile::index_t stride_A = is_a_row_major ? K : M;
+        const ck_tile::index_t stride_B = is_b_row_major ? N : K;
+        const ck_tile::index_t stride_C = is_c_row_major ? N : M;
+
+        // Allocate and initialize tensors
+        ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
+            M, K, stride_A, ck_tile::bool_constant<is_a_row_major>{}));
+        ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
+            K, N, stride_B, ck_tile::bool_constant<is_b_row_major>{}));
+        ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
+            M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
+
+        ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
+        ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);
+
+        ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
+        ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
+        ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
+
+        a_m_k_dev_buf.ToDevice(a_m_k.data());
+        b_k_n_dev_buf.ToDevice(b_k_n.data());
+        c_m_n_dev_buf.SetZero();
+        c_m_n_dev_result.SetZero();
+
+        ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
+                                      b_k_n_dev_buf.GetDeviceBuffer(),
+                                      c_m_n_dev_buf.GetDeviceBuffer(),
+                                      kbatch,
+                                      M,
+                                      N,
+                                      K,
+                                      stride_A,
+                                      stride_B,
+                                      stride_C};
+
+        Invoker::template test_async_input_scheduler<ADataType,
+                                                     BDataType,
+                                                     ck_tile::tuple<>,
+                                                     AccDataType,
+                                                     CDataType,
+                                                     ALayout,
+                                                     BLayout,
+                                                     ck_tile::tuple<>,
+                                                     CLayout,
+                                                     ck_tile::element_wise::PassThrough>(
+            args, ck_tile::stream_config{nullptr, false, 1});
+
+        // Copy result from device for verification
+        c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
+
+        // Compute CPU reference
+        ck_tile::HostTensor<CDataType> c_m_n_ref(ck_tile::host_tensor_descriptor(
+            M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
+        c_m_n_ref.SetZero();
+        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
+            a_m_k, b_k_n, c_m_n_ref);
+
+        // Verify results
+        const float max_accumulated_value =
+            *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
+        const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
+            K, kbatch, max_accumulated_value);
+        bool pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
+
+        std::cout << "Async input scheduler test: " << (pass ? "PASS" : "FAIL") << std::endl;
+        return pass;
+    }
+
+    // Normal path - delegate to shared implementation
+    return run_gemm_example_with_layouts<GemmConfig, ADataType, BDataType, CDataType>(
+        arg_parser, a_layout, b_layout, c_layout);
+}
+
+// Universal GEMM-specific prec_type dispatcher that uses the wrapper
+template <typename GemmConfig, typename APrecType, typename BPrecType, typename CPrecType>
+int run_gemm_example_prec_type_universal(std::string a_layout,
+                                         std::string b_layout,
+                                         ck_tile::ArgParser& arg_parser)
+{
+    using Row = ck_tile::tensor_layout::gemm::RowMajor;
+    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
+    bool preshuffle = GemmConfig::Preshuffle;
+
+    if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
+    {
+        throw std::runtime_error("Preshuffle is not supported for this int4 datatype!");
+    }
+
+    if(preshuffle && a_layout != "R" && b_layout != "C")
+    {
+        throw std::runtime_error(
+            "Preshuffle is supported only for A(Row major), B(column major) input matrices!");
+    }
+
+    using LayoutVariant = std::variant<Row, Col>;
+
+    auto string_to_layout = [](const std::string& layout) -> LayoutVariant {
+        if(layout == "R")
+            return Row{};
+        if(layout == "C")
+            return Col{};
+        throw std::runtime_error("Unsupported layout: " + layout);
+    };
+
+    auto a_layout_variant = string_to_layout(a_layout);
+    auto b_layout_variant = string_to_layout(b_layout);
+
+    return std::visit(
+        [&](auto a_layout_type, auto b_layout_type) -> int {
+            if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t> &&
+                         std::is_same_v<decltype(b_layout_type), Row>)
+            {
+                throw std::runtime_error("Unsupported memory layout for the input matrices when "
+                                         "BPrecType is ck_tile::pk_int4_t!");
+            }
+            else
+            {
+                return run_gemm_example_with_layouts_universal<GemmConfig,
+                                                               APrecType,
+                                                               BPrecType,
+                                                               CPrecType>(
+                    arg_parser, a_layout_type, b_layout_type, Row{});
+            }
+        },
+        a_layout_variant,
+        b_layout_variant);
+}
+
 template