mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add persistent async input scheduler for GEMM kernels (#3520)
Add signal-based synchronization for persistent GEMM kernels where
input data becomes available incrementally. Uses modulo wraparound
(like PyTorch's AsyncMM) for chunk index calculation:
chunk_idx = ((tile_idx + tile_idx_pivot) / tiles_per_chunk) % num_chunks
Key components:
- PersistentAsyncInputScheduler struct with tiles_per_chunk_m,
chunk_signals, tile_idx_pivot_m, and num_chunks fields
- wait_eq_wave method using __builtin_amdgcn_s_sleep for power efficiency
- IsSupportedArgument validation for scheduler parameters
- Example demonstrating async input scheduling with simulated producer
- GTest unit tests covering all layout combinations
[ROCm/composable_kernel commit: 91b4102a59]
This commit is contained in:
@@ -15,6 +15,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added support for gfx1153 target.
|
||||
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
|
||||
* Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines.
|
||||
* Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming.
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
@@ -456,7 +456,8 @@ inline auto create_args()
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "gemm.json", "json file name to dump results")
|
||||
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
|
||||
.insert("rotating_count", "1000", "rotating count, defaults to 1000");
|
||||
.insert("rotating_count", "1000", "rotating count, defaults to 1000")
|
||||
.insert("test_async", "0", "0: normal gemm, 1: test async input scheduler");
|
||||
return arg_parser;
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,169 @@
|
||||
#include "run_gemm_example_common.hpp"
|
||||
#include "universal_gemm_invoker.hpp"
|
||||
|
||||
// Universal GEMM-specific wrapper that handles test_async flag
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType = ADataType,
|
||||
typename CDataType = ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_gemm_example_with_layouts_universal(ck_tile::ArgParser& arg_parser,
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const CLayout c_layout = CLayout{})
|
||||
{
|
||||
using Invoker = UniversalInvoker;
|
||||
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
|
||||
|
||||
// Check for async input scheduler test mode
|
||||
bool test_async = arg_parser.get_int("test_async");
|
||||
if(test_async)
|
||||
{
|
||||
// Extract parameters for async test (same as shared implementation)
|
||||
const ck_tile::index_t M = arg_parser.get_int("m");
|
||||
const ck_tile::index_t N = arg_parser.get_int("n");
|
||||
const ck_tile::index_t K = arg_parser.get_int("k");
|
||||
const ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
constexpr bool is_a_row_major = std::is_same_v<ALayout, Row>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, Row>;
|
||||
constexpr bool is_c_row_major = std::is_same_v<CLayout, Row>;
|
||||
|
||||
const ck_tile::index_t stride_A = is_a_row_major ? K : M;
|
||||
const ck_tile::index_t stride_B = is_b_row_major ? N : K;
|
||||
const ck_tile::index_t stride_C = is_c_row_major ? N : M;
|
||||
|
||||
// Allocate and initialize tensors
|
||||
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
|
||||
M, K, stride_A, ck_tile::bool_constant<is_a_row_major>{}));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
|
||||
K, N, stride_B, ck_tile::bool_constant<is_b_row_major>{}));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
|
||||
M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C};
|
||||
|
||||
Invoker::template test_async_input_scheduler<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough>(
|
||||
args, ck_tile::stream_config{nullptr, false, 1});
|
||||
|
||||
// Copy result from device for verification
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
// Compute CPU reference
|
||||
ck_tile::HostTensor<CDataType> c_m_n_ref(ck_tile::host_tensor_descriptor(
|
||||
M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
|
||||
c_m_n_ref.SetZero();
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_ref);
|
||||
|
||||
// Verify results
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
|
||||
|
||||
std::cout << "Async input scheduler test: " << (pass ? "PASS" : "FAIL") << std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
// Normal path - delegate to shared implementation
|
||||
return run_gemm_example_with_layouts<GemmConfig, Invoker, ADataType, BDataType, CDataType>(
|
||||
arg_parser, a_layout, b_layout, c_layout);
|
||||
}
|
||||
|
||||
// Universal GEMM-specific prec_type dispatcher that uses the wrapper
|
||||
template <typename GemmConfig,
|
||||
typename APrecType,
|
||||
typename BPrecType = APrecType,
|
||||
typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type_universal(std::string a_layout,
|
||||
std::string b_layout,
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
throw std::runtime_error("Preshuffle is not supported for this int4 datatype!");
|
||||
}
|
||||
|
||||
if(preshuffle && a_layout != "R" && b_layout != "C")
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
|
||||
}
|
||||
|
||||
using LayoutVariant = std::variant<Row, Col>;
|
||||
|
||||
auto string_to_layout = [](const std::string& layout) -> LayoutVariant {
|
||||
if(layout == "R")
|
||||
return Row{};
|
||||
if(layout == "C")
|
||||
return Col{};
|
||||
throw std::runtime_error("Unsupported layout: " + layout);
|
||||
};
|
||||
|
||||
auto a_layout_variant = string_to_layout(a_layout);
|
||||
auto b_layout_variant = string_to_layout(b_layout);
|
||||
|
||||
return std::visit(
|
||||
[&](auto a_layout_type, auto b_layout_type) -> int {
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t> &&
|
||||
std::is_same_v<decltype(b_layout_type), Row>)
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices when "
|
||||
"BPrecType is ck_tile::pk_int4_t!");
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_with_layouts_universal<GemmConfig,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
CPrecType>(
|
||||
arg_parser, a_layout_type, b_layout_type, Row{});
|
||||
}
|
||||
},
|
||||
a_layout_variant,
|
||||
b_layout_variant);
|
||||
}
|
||||
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
@@ -19,52 +182,50 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
using Invoker = UniversalInvoker;
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, Invoker, ck_tile::half_t>(
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, Invoker, ck_tile::bf16_t>(
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
Invoker,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
Invoker,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "int8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
|
||||
Invoker,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t>(a_layout, b_layout, arg_parser);
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::int8_t>,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "fp16i4")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
|
||||
Invoker,
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -75,11 +236,11 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
Invoker,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -90,11 +251,11 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
Invoker,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -2,7 +2,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include <functional>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include "gemm_utils.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
|
||||
struct UniversalInvoker
|
||||
{
|
||||
@@ -150,4 +154,170 @@ struct UniversalInvoker
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename CDEElementWise>
|
||||
static void test_async_input_scheduler(const ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
true, // Persistent = true for async test
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
const ck_tile::index_t tiles_m =
|
||||
ck_tile::integer_divide_ceil(args.M, TilePartitioner::MPerBlock);
|
||||
// Balance signal granularity (smaller chunks = finer control) vs overhead (more signals)
|
||||
const ck_tile::index_t tiles_per_chunk = 2;
|
||||
// Shift chunk assignments to test wraparound behavior
|
||||
const ck_tile::index_t tile_idx_pivot = tiles_per_chunk;
|
||||
// Account for pivot when allocating signal buffer
|
||||
const ck_tile::index_t num_chunks =
|
||||
ck_tile::integer_divide_ceil(tiles_m + tile_idx_pivot, tiles_per_chunk);
|
||||
|
||||
std::cout << "Async Input Scheduler Test:" << std::endl;
|
||||
std::cout << " M tiles: " << tiles_m << std::endl;
|
||||
std::cout << " Tiles per chunk: " << tiles_per_chunk << std::endl;
|
||||
std::cout << " Tile index pivot: " << tile_idx_pivot << std::endl;
|
||||
std::cout << " Number of signal chunks: " << num_chunks << std::endl;
|
||||
|
||||
// Signals must start as zero so kernel blocks until producer sets them
|
||||
ck_tile::DeviceMem signal_buf(num_chunks * sizeof(uint32_t));
|
||||
signal_buf.SetZero();
|
||||
uint32_t* d_chunk_signals = static_cast<uint32_t*>(signal_buf.GetDeviceBuffer());
|
||||
|
||||
// Setup async input scheduler
|
||||
ck_tile::PersistentAsyncInputScheduler async_scheduler;
|
||||
async_scheduler.tiles_per_chunk_m = tiles_per_chunk;
|
||||
async_scheduler.chunk_signals = d_chunk_signals;
|
||||
async_scheduler.tile_idx_pivot_m = tile_idx_pivot;
|
||||
async_scheduler.num_chunks = num_chunks;
|
||||
|
||||
// Create modified host args with async scheduler
|
||||
ck_tile::UniversalGemmHostArgs<1, 1, 0> host_args({args.a_ptr},
|
||||
{args.b_ptr},
|
||||
{},
|
||||
args.e_ptr,
|
||||
args.k_batch,
|
||||
args.M,
|
||||
args.N,
|
||||
args.K,
|
||||
{args.stride_A},
|
||||
{args.stride_B},
|
||||
{},
|
||||
args.stride_E,
|
||||
async_scheduler);
|
||||
|
||||
auto kargs = Kernel::UniversalGemmKernel::MakeKernelArgs(host_args);
|
||||
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
std::cout << " Grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< std::endl;
|
||||
std::cout << " Blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
|
||||
// Separate stream prevents deadlock: kernel and signal producer must run concurrently
|
||||
hipStream_t signal_stream;
|
||||
HIP_CHECK_ERROR(hipStreamCreateWithFlags(&signal_stream, hipStreamNonBlocking));
|
||||
|
||||
const auto start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
// Simulate incremental input arrival by delaying signal activation
|
||||
const int sleep_us = 100;
|
||||
for(ck_tile::index_t i = 0; i < num_chunks; ++i)
|
||||
{
|
||||
std::this_thread::sleep_for(std::chrono::microseconds(sleep_us));
|
||||
const uint32_t signal_val = 1;
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(d_chunk_signals + i,
|
||||
&signal_val,
|
||||
sizeof(uint32_t),
|
||||
hipMemcpyHostToDevice,
|
||||
signal_stream));
|
||||
}
|
||||
HIP_CHECK_ERROR(hipStreamSynchronize(signal_stream));
|
||||
HIP_CHECK_ERROR(hipStreamDestroy(signal_stream));
|
||||
|
||||
// Wait for kernel completion
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
std::chrono::high_resolution_clock::now() - start);
|
||||
|
||||
std::cout << " Total time: " << duration.count() << " us" << std::endl;
|
||||
std::cout << " Sleep time: " << (num_chunks * sleep_us) << " us" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -91,6 +91,7 @@
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
#include "ck_tile/core/utility/literals.hpp"
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/persistent_async_input_scheduler.hpp"
|
||||
#include "ck_tile/core/utility/philox_rand.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
|
||||
@@ -26,6 +26,36 @@ struct workgroup_barrier
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Reduces power consumption during polling by leveraging wave-level sleep instructions
|
||||
CK_TILE_DEVICE void wait_eq_wave(uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
// Limit active polling to first wave to reduce memory traffic and power
|
||||
const uint32_t wave_size = static_cast<uint32_t>(warpSize);
|
||||
if(threadIdx.x < wave_size)
|
||||
{
|
||||
uint32_t loaded_value = 0;
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
loaded_value = ld(offset);
|
||||
}
|
||||
loaded_value = __shfl(loaded_value, 0 /*src_lane*/);
|
||||
|
||||
while(loaded_value != value)
|
||||
{
|
||||
// s_sleep reduces power draw while waiting, as scalar sleep is cheaper than
|
||||
// busy-wait
|
||||
__builtin_amdgcn_s_sleep(1);
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
loaded_value = ld(offset);
|
||||
}
|
||||
loaded_value = __shfl(loaded_value, 0 /*src_lane*/);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void wait_lt(uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Scheduler for persistent GEMM kernels with asynchronous input streaming.
|
||||
///
|
||||
/// This structure enables signal-based synchronization for persistent kernels where input data
|
||||
/// becomes available incrementally. It divides M-dimension tiles into chunks and uses signals
|
||||
/// to coordinate between the input producer and the kernel consumer.
|
||||
///
|
||||
/// Uses modulo wraparound (like PyTorch's AsyncMM) for chunk index calculation:
|
||||
/// chunk_idx = ((tile_idx + tile_idx_pivot_m) / tiles_per_chunk_m) % num_chunks
|
||||
///
|
||||
/// @par Typical usage pattern:
|
||||
/// 1. Set tiles_per_chunk_m to group tiles into chunks (e.g., 2 or 4 tiles per chunk)
|
||||
/// 2. Set tile_idx_pivot_m as offset for chunk calculation
|
||||
/// 3. Set num_chunks = ceil((tiles_m + tile_idx_pivot_m) / tiles_per_chunk_m)
|
||||
/// 4. Allocate chunk_signals array with size = num_chunks
|
||||
/// 5. Producer sets chunk_signals[i] = 1 when chunk i's data is ready
|
||||
/// 6. Kernel waits for chunk_signals[chunk_idx] before processing each tile
|
||||
struct PersistentAsyncInputScheduler
|
||||
{
|
||||
/// @brief Number of M-dimension tiles grouped into each chunk.
|
||||
/// Grouping tiles balances synchronization overhead against input streaming granularity.
|
||||
/// Set to 0 to disable async scheduling.
|
||||
uint32_t tiles_per_chunk_m = 0;
|
||||
|
||||
/// @brief Device pointer to array of signal values (uint32_t), one per chunk.
|
||||
/// Producer sets signals to coordinate when input data is ready for consumption.
|
||||
/// Set to nullptr to disable async scheduling.
|
||||
uint32_t* chunk_signals = nullptr;
|
||||
|
||||
/// @brief Pivot offset for rotating the chunk assignment.
|
||||
/// Allows shifting which tiles map to which chunks, useful for load balancing.
|
||||
/// chunk_idx = ((tile_idx + tile_idx_pivot_m) / tiles_per_chunk_m) % num_chunks
|
||||
int32_t tile_idx_pivot_m = 0;
|
||||
|
||||
/// @brief Number of signal chunks allocated.
|
||||
/// Must equal ceil((tiles_m + tile_idx_pivot_m) / tiles_per_chunk_m).
|
||||
/// Modulo wraparound prevents out-of-bounds access when pivot shifts chunk assignment.
|
||||
uint32_t num_chunks = 0;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -13,6 +13,8 @@
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/persistent_async_input_scheduler.hpp"
|
||||
#include "ck_tile/core/arch/workgroup_barrier.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -30,18 +32,20 @@ namespace ck_tile {
|
||||
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
|
||||
struct UniversalGemmHostArgs
|
||||
{
|
||||
CK_TILE_HOST UniversalGemmHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
|
||||
const std::array<const void*, NumBTensor>& bs_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
const std::array<index_t, NumATensor>& stride_As_,
|
||||
const std::array<index_t, NumBTensor>& stride_Bs_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_)
|
||||
CK_TILE_HOST UniversalGemmHostArgs(
|
||||
const std::array<const void*, NumATensor>& as_ptr_,
|
||||
const std::array<const void*, NumBTensor>& bs_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
const std::array<index_t, NumATensor>& stride_As_,
|
||||
const std::array<index_t, NumBTensor>& stride_Bs_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_,
|
||||
PersistentAsyncInputScheduler async_input_scheduler_ = PersistentAsyncInputScheduler{})
|
||||
: as_ptr(as_ptr_),
|
||||
bs_ptr(bs_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
@@ -53,7 +57,8 @@ struct UniversalGemmHostArgs
|
||||
stride_Bs(stride_Bs_),
|
||||
stride_Ds(stride_Ds_),
|
||||
stride_E(stride_E_),
|
||||
k_batch(k_batch_)
|
||||
k_batch(k_batch_),
|
||||
async_input_scheduler(async_input_scheduler_)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -78,6 +83,7 @@ struct UniversalGemmHostArgs
|
||||
};
|
||||
|
||||
index_t k_batch;
|
||||
PersistentAsyncInputScheduler async_input_scheduler;
|
||||
};
|
||||
|
||||
/// @brief The GEMM kernel device arguments.
|
||||
@@ -111,6 +117,8 @@ struct UniversalGemmKernelArgs
|
||||
/// (in memory) of E tensor.
|
||||
index_t stride_E;
|
||||
index_t k_batch;
|
||||
/// @brief Persistent async input scheduler for chunk-based tile scheduling.
|
||||
PersistentAsyncInputScheduler async_input_scheduler = {};
|
||||
};
|
||||
|
||||
/// @brief The Universal GEMM kernel template.
|
||||
@@ -201,7 +209,7 @@ struct UniversalGemmKernel
|
||||
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
// Get the persistent kernel if the pipeline has it available
|
||||
// Detect persistent kernel support to select appropriate entry point
|
||||
struct has_persistent_kernel
|
||||
{
|
||||
template <typename T>
|
||||
@@ -216,7 +224,7 @@ struct UniversalGemmKernel
|
||||
};
|
||||
static constexpr bool PersistentKernel = has_persistent_kernel::value;
|
||||
|
||||
// Check if TilePartitioner has GetOutputOffset method with kargs and k_id
|
||||
// Detect custom output offset support for advanced partitioning schemes
|
||||
struct has_tile_partitioner_output_offset_impl
|
||||
{
|
||||
template <typename T, typename KernelArgs>
|
||||
@@ -272,10 +280,10 @@ struct UniversalGemmKernel
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
|
||||
* @return The maximum occupancy grid size.
|
||||
* @note This function queries the maximum occupancy of the kernel using
|
||||
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
* @brief Calculate grid size that maximizes hardware utilization for persistent kernels.
|
||||
* @return Grid size that fills all compute units at maximum occupancy.
|
||||
* @note Persistent kernels loop over tiles, so grid size should match hardware capacity
|
||||
* rather than problem size.
|
||||
*/
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
@@ -315,7 +323,8 @@ struct UniversalGemmKernel
|
||||
hostArgs.stride_Bs,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E,
|
||||
hostArgs.k_batch};
|
||||
hostArgs.k_batch,
|
||||
hostArgs.async_input_scheduler};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -325,11 +334,8 @@ struct UniversalGemmKernel
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
// This structure distributes work evenly among splitkk workgroups
|
||||
// It's based on a principle that if there is enough work to fill all workgroups,
|
||||
// then we can distribute the (K / K1) parts among k_batch workgroups in such a way
|
||||
// that each workgroup will be doing ceil((K / K1) / splitk) or ceil((K / K1) / splitk) - 1
|
||||
// and leave the potential tail for last(splitk - 1) indexed workgroup.
|
||||
// Balances K-dimension work across batches to maximize parallelism while minimizing
|
||||
// load imbalance. Uses ceil division to distribute remainder work evenly.
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const index_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
@@ -658,6 +664,28 @@ struct UniversalGemmKernel
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Verify async scheduler parameters to prevent division-by-zero and invalid memory access
|
||||
if(kargs.async_input_scheduler.chunk_signals != nullptr)
|
||||
{
|
||||
if(kargs.async_input_scheduler.tiles_per_chunk_m == 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("tiles_per_chunk_m must be positive when chunk_signals is set!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.async_input_scheduler.num_chunks == 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("num_chunks must be positive when chunk_signals is set!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return AsTensorIsValid && BsTensorIsValid && DTensorIsValid;
|
||||
}
|
||||
|
||||
@@ -1177,12 +1205,30 @@ struct UniversalGemmKernel
|
||||
while(block_id < num_work)
|
||||
{
|
||||
s_waitcnt_barrier();
|
||||
// Get the tile index for this block
|
||||
const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
// Synchronize with producer to ensure input data is ready before processing tile
|
||||
if(kargs.async_input_scheduler.chunk_signals != nullptr)
|
||||
{
|
||||
const auto tiles_per_chunk =
|
||||
amd_wave_read_first_lane(kargs.async_input_scheduler.tiles_per_chunk_m);
|
||||
const auto tile_idx_pivot =
|
||||
amd_wave_read_first_lane(kargs.async_input_scheduler.tile_idx_pivot_m);
|
||||
const auto num_chunks =
|
||||
amd_wave_read_first_lane(kargs.async_input_scheduler.num_chunks);
|
||||
if(tiles_per_chunk > 0 && num_chunks > 0)
|
||||
{
|
||||
// Pivot allows rotating chunk assignments for load balancing
|
||||
const auto chunk_idx = amd_wave_read_first_lane(
|
||||
((iM + tile_idx_pivot) / tiles_per_chunk) % num_chunks);
|
||||
workgroup_barrier chunk_barrier(kargs.async_input_scheduler.chunk_signals);
|
||||
chunk_barrier.wait_eq_wave(/*value=*/1, /*offset=*/chunk_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Get the SplitK offset for this block
|
||||
const auto k_batch = amd_wave_read_first_lane(block_id / num_tiles);
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
add_subdirectory(image_to_column)
|
||||
add_subdirectory(gemm)
|
||||
add_subdirectory(gemm_persistent_async_input)
|
||||
add_subdirectory(gemm_weight_preshuffle)
|
||||
add_subdirectory(batched_gemm)
|
||||
add_subdirectory(grouped_gemm)
|
||||
|
||||
19
test/ck_tile/gemm_persistent_async_input/CMakeLists.txt
Normal file
19
test/ck_tile/gemm_persistent_async_input/CMakeLists.txt
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# Test for persistent async input GEMM - currently targeting gfx95
|
||||
set(PERSISTENT_ASYNC_INPUT_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND PERSISTENT_ASYNC_INPUT_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
list(APPEND PERSISTENT_ASYNC_INPUT_COMPILE_OPTIONS
|
||||
-mllvm
|
||||
-enable-noalias-to-md-conversion=0
|
||||
)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_gemm_persistent_async_input test_gemm_persistent_async_input.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_persistent_async_input PRIVATE ${PERSISTENT_ASYNC_INPUT_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping test_ck_tile_gemm_persistent_async_input for current target - requires gfx95")
|
||||
endif()
|
||||
@@ -0,0 +1,304 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/core/utility/persistent_async_input_scheduler.hpp"
|
||||
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using F16 = ck_tile::fp16_t;
|
||||
using F32 = ck_tile::fp32_t;
|
||||
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave>;
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType>
|
||||
class TestGemmPersistentAsyncInput : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
// Use larger M to ensure tiles_m > tile_idx_pivot, exercising the async scheduler
|
||||
static constexpr ck_tile::index_t M = 1536; // 6 tiles with M_Tile=256
|
||||
static constexpr ck_tile::index_t N = 1024;
|
||||
static constexpr ck_tile::index_t K = 512;
|
||||
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
template <bool IsRowMajor>
|
||||
static constexpr ck_tile::index_t get_default_stride(ck_tile::index_t row, ck_tile::index_t col)
|
||||
{
|
||||
if constexpr(IsRowMajor)
|
||||
return col;
|
||||
else
|
||||
return row;
|
||||
}
|
||||
|
||||
void Run()
|
||||
{
|
||||
constexpr bool is_a_row_major = std::is_same_v<ALayout, Row>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, Row>;
|
||||
constexpr bool is_c_row_major = std::is_same_v<CLayout, Row>;
|
||||
|
||||
ck_tile::index_t stride_A = get_default_stride<is_a_row_major>(M, K);
|
||||
ck_tile::index_t stride_B = get_default_stride<is_b_row_major>(K, N);
|
||||
ck_tile::index_t stride_C = get_default_stride<is_c_row_major>(M, N);
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
|
||||
M, K, stride_A, ck_tile::bool_constant<is_a_row_major>{}));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
|
||||
K, N, stride_B, ck_tile::bool_constant<is_b_row_major>{}));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
|
||||
M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
|
||||
M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
|
||||
|
||||
// Fill input tensors with random values
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, 11939}(a_m_k);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, 11940}(b_k_n);
|
||||
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
// Copy input data to device
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Compute reference result on host
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
|
||||
// Setup kernel configuration for persistent async input GEMM
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr bool kPadM = true;
|
||||
constexpr bool kPadN = true;
|
||||
constexpr bool kPadK = true;
|
||||
constexpr bool DoubleSmemBuffer = true;
|
||||
constexpr bool TransposeC = false;
|
||||
constexpr bool StructuredSparsity = false;
|
||||
constexpr bool Persistent = true;
|
||||
constexpr int NumWaveGroup = 1;
|
||||
constexpr bool Preshuffle = false;
|
||||
constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TilePartitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
TilePartitionerGroupNum,
|
||||
TilePartitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC,
|
||||
StructuredSparsity,
|
||||
Persistent,
|
||||
NumWaveGroup,
|
||||
Preshuffle>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
Intrawave::value>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<UniversalGemmProblem>;
|
||||
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
1, // kNumWaveGroups_
|
||||
false, // FixedVectorSize_
|
||||
1, // VectorSizeC_
|
||||
false, // TiledMMAPermuteN_
|
||||
1, // BlockedXDLN_PerWarp_
|
||||
DoubleSmemBuffer>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
// Calculate tiles and chunks for async scheduler.
|
||||
// Uses modulo wraparound like PyTorch - chunk_idx = (iM + pivot) / tiles_per_chunk %
|
||||
// num_chunks
|
||||
constexpr ck_tile::index_t tiles_per_chunk = 2;
|
||||
constexpr ck_tile::index_t tile_idx_pivot = 2;
|
||||
|
||||
const ck_tile::index_t tiles_m = ck_tile::integer_divide_ceil(M, M_Tile);
|
||||
// With add logic, max chunk_idx = (tiles_m - 1 + pivot) / tiles_per_chunk
|
||||
// So num_chunks = ceil((tiles_m + pivot) / tiles_per_chunk)
|
||||
const ck_tile::index_t num_chunks =
|
||||
ck_tile::integer_divide_ceil(tiles_m + tile_idx_pivot, tiles_per_chunk);
|
||||
|
||||
// Validate async scheduler configuration
|
||||
// With M=1536, M_Tile=256: tiles_m=6, num_chunks=ceil((6+2)/2)=4
|
||||
ASSERT_GT(num_chunks, 0) << "Test requires num_chunks > 0 to exercise async scheduler";
|
||||
ASSERT_GT(tiles_per_chunk, 0) << "tiles_per_chunk must be positive";
|
||||
|
||||
// Allocate chunk signals (initialized to zero)
|
||||
ck_tile::DeviceMem signal_buf(num_chunks * sizeof(uint32_t));
|
||||
signal_buf.SetZero();
|
||||
uint32_t* d_chunk_signals = static_cast<uint32_t*>(signal_buf.GetDeviceBuffer());
|
||||
ASSERT_NE(d_chunk_signals, nullptr) << "Failed to allocate signal buffer";
|
||||
|
||||
// Setup async input scheduler
|
||||
ck_tile::PersistentAsyncInputScheduler async_scheduler;
|
||||
async_scheduler.tiles_per_chunk_m = tiles_per_chunk;
|
||||
async_scheduler.chunk_signals = d_chunk_signals;
|
||||
async_scheduler.tile_idx_pivot_m = tile_idx_pivot;
|
||||
async_scheduler.num_chunks = num_chunks;
|
||||
|
||||
// Create UniversalGemmHostArgs with async scheduler
|
||||
ck_tile::UniversalGemmHostArgs<1, 1, 0> host_args({a_m_k_dev_buf.GetDeviceBuffer()},
|
||||
{b_k_n_dev_buf.GetDeviceBuffer()},
|
||||
{},
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
{stride_A},
|
||||
{stride_B},
|
||||
{},
|
||||
stride_C,
|
||||
async_scheduler);
|
||||
|
||||
// Create kernel args using UniversalGemmKernel
|
||||
auto kargs = Kernel::UniversalGemmKernel::MakeKernelArgs(host_args);
|
||||
|
||||
// Validate kernel args match host configuration
|
||||
ASSERT_EQ(kargs.async_input_scheduler.chunk_signals, d_chunk_signals)
|
||||
<< "Kernel args chunk_signals doesn't match host configuration";
|
||||
ASSERT_EQ(kargs.async_input_scheduler.tiles_per_chunk_m,
|
||||
static_cast<uint32_t>(tiles_per_chunk))
|
||||
<< "Kernel args tiles_per_chunk_m doesn't match host configuration";
|
||||
ASSERT_EQ(kargs.async_input_scheduler.tile_idx_pivot_m,
|
||||
static_cast<int32_t>(tile_idx_pivot))
|
||||
<< "Kernel args tile_idx_pivot_m doesn't match host configuration";
|
||||
|
||||
// Setup grid and blocks for persistent kernel
|
||||
ck_tile::stream_config stream_cfg{nullptr, false};
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(stream_cfg);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
GTEST_SKIP() << "Kernel arguments not supported, skipping test";
|
||||
return;
|
||||
}
|
||||
|
||||
// Create a separate stream for setting signals
|
||||
// Using the same stream would deadlock - memcpy waits for kernel, kernel waits for signal
|
||||
hipStream_t signal_stream;
|
||||
HIP_CHECK_ERROR(hipStreamCreateWithFlags(&signal_stream, hipStreamNonBlocking));
|
||||
|
||||
// Launch kernel
|
||||
ck_tile::ignore = ck_tile::launch_kernel(
|
||||
stream_cfg, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
// Simulate producer setting chunk signals with interleaved sleep
|
||||
// This simulates async input becoming available over time
|
||||
const int sleep_us = 100; // microseconds between chunks
|
||||
for(ck_tile::index_t i = 0; i < num_chunks; ++i)
|
||||
{
|
||||
std::this_thread::sleep_for(std::chrono::microseconds(sleep_us));
|
||||
const uint32_t signal_val = 1;
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(d_chunk_signals + i,
|
||||
&signal_val,
|
||||
sizeof(uint32_t),
|
||||
hipMemcpyHostToDevice,
|
||||
signal_stream));
|
||||
}
|
||||
|
||||
// Wait for all signals to be set
|
||||
HIP_CHECK_ERROR(hipStreamSynchronize(signal_stream));
|
||||
HIP_CHECK_ERROR(hipStreamDestroy(signal_stream));
|
||||
|
||||
// Wait for kernel completion
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
// Copy result back to host
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
// Validate results
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
|
||||
const auto rtol = ck_tile::get_relative_threshold<ADataType, CDataType, AccDataType>(K);
|
||||
const auto atol = ck_tile::get_absolute_threshold<ADataType, CDataType, AccDataType>(
|
||||
max_accumulated_value, K);
|
||||
|
||||
bool pass = ck_tile::check_err(
|
||||
c_m_n_dev_result, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
// Define test types for different layout combinations
|
||||
using RowRowRow_F16F16F32F16 = TestGemmPersistentAsyncInput<Row, Row, Row, F16, F16, F32, F16>;
|
||||
using RowColRow_F16F16F32F16 = TestGemmPersistentAsyncInput<Row, Col, Row, F16, F16, F32, F16>;
|
||||
using ColRowRow_F16F16F32F16 = TestGemmPersistentAsyncInput<Col, Row, Row, F16, F16, F32, F16>;
|
||||
using ColColRow_F16F16F32F16 = TestGemmPersistentAsyncInput<Col, Col, Row, F16, F16, F32, F16>;
|
||||
|
||||
// Test case for Row-Row-Row layout
|
||||
TEST_F(RowRowRow_F16F16F32F16, BasicTest) { this->Run(); }
|
||||
|
||||
// Test case for Row-Col-Row layout
|
||||
TEST_F(RowColRow_F16F16F32F16, BasicTest) { this->Run(); }
|
||||
|
||||
// Test case for Col-Row-Row layout
|
||||
TEST_F(ColRowRow_F16F16F32F16, BasicTest) { this->Run(); }
|
||||
|
||||
// Test case for Col-Col-Row layout
|
||||
TEST_F(ColColRow_F16F16F32F16, BasicTest) { this->Run(); }
|
||||
Reference in New Issue
Block a user