From 5fe763239373c6682afacbbd87ef8c87d8ce5087 Mon Sep 17 00:00:00 2001
From: Damien Lejeune
Date: Thu, 12 Feb 2026 09:24:15 +0000
Subject: [PATCH] Add V5: split-k

---
 example/ck_tile/42_mhc/CMakeLists.txt         |   3 +
 .../ck_tile/42_mhc/mhc_v5_bf16_benchmark.cpp  | 319 ++++++++++++++
 include/ck_tile/ops/mhc.hpp                   |   2 +
 .../ops/mhc/kernel/mhc_kernel_tile_v5.hpp     | 409 ++++++++++++++++++
 .../ops/mhc/pipeline/mhc_problem_v5.hpp       | 126 ++++++
 5 files changed, 859 insertions(+)
 create mode 100644 example/ck_tile/42_mhc/mhc_v5_bf16_benchmark.cpp
 create mode 100644 include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v5.hpp
 create mode 100644 include/ck_tile/ops/mhc/pipeline/mhc_problem_v5.hpp

diff --git a/example/ck_tile/42_mhc/CMakeLists.txt b/example/ck_tile/42_mhc/CMakeLists.txt
index 446e92481c..a2700f51a0 100644
--- a/example/ck_tile/42_mhc/CMakeLists.txt
+++ b/example/ck_tile/42_mhc/CMakeLists.txt
@@ -18,3 +18,6 @@ add_executable(${TARGET_NAME} mhc_v3_bf16_benchmark.cpp)
 
 set(TARGET_NAME example_mhc_v4_bf16_benchmark)
 add_executable(${TARGET_NAME} mhc_v4_bf16_benchmark.cpp)
+
+set(TARGET_NAME example_mhc_v5_bf16_benchmark)
+add_executable(${TARGET_NAME} mhc_v5_bf16_benchmark.cpp)
diff --git a/example/ck_tile/42_mhc/mhc_v5_bf16_benchmark.cpp b/example/ck_tile/42_mhc/mhc_v5_bf16_benchmark.cpp
new file mode 100644
index 0000000000..82569118c7
--- /dev/null
+++ b/example/ck_tile/42_mhc/mhc_v5_bf16_benchmark.cpp
@@ -0,0 +1,319 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include <iostream>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <typeinfo>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/ops/mhc.hpp"
+#include "ck_tile/host/kernel_launch.hpp"
+#include "ck_tile/host/reference/reference_mhc.hpp"
+#include "ck_tile/host/check_err.hpp"
+
+// Parse command-line arguments for MHC benchmark
+auto create_args(int argc, char* argv[])
+{
+    ck_tile::ArgParser arg_parser;
+    arg_parser.insert("B", "1024", "Batch size")
+        .insert("n", "4", "Expansion factor (number of streams)")
+        .insert("C", "4096", "Channels per stream")
+        .insert("v", "1", "CPU validation (0=no, 1=yes)")
+        .insert("warmup", "5", "Number of warmup iterations")
+        .insert("repeat", "20", "Number of benchmark iterations")
+        .insert("r", "2.0", "Norm scaling factor")
+        .insert("alpha_pre", "1.5", "Alpha for pre-activation")
+        .insert("alpha_post", "2.5", "Alpha for post-activation")
+        .insert("alpha_res", "3.5", "Alpha for residual")
+        .insert("bias", "1.5", "Bias value");
+
+    bool result = arg_parser.parse(argc, argv);
+    return std::make_tuple(result, arg_parser);
+}
+
+template <typename XDataType,
+          typename PhiDataType,
+          typename YDataType,
+          typename ComputeDataType,
+          typename ActivationFunc,
+          ck_tile::index_t kMTileSize> // Template parameter for M tile size
+bool run_mhc_benchmark_impl(const ck_tile::ArgParser& arg_parser)
+{
+    const int B = arg_parser.get_int("B");
+    const int n = arg_parser.get_int("n");
+    const int C = arg_parser.get_int("C");
+
+    const int nC         = n * C;
+    const int output_dim = 2 * n + n * n;
+
+    const int do_validation = arg_parser.get_int("v");
+    const int warmup        = arg_parser.get_int("warmup");
+    const int repeat        = arg_parser.get_int("repeat");
+
+    const float r          = arg_parser.get_float("r");
+    const float alpha_pre  = arg_parser.get_float("alpha_pre");
+    const float alpha_post = arg_parser.get_float("alpha_post");
+    const float alpha_res  = arg_parser.get_float("alpha_res");
+    const float bias       = arg_parser.get_float("bias");
+
+    std::cout << "\n========================================" << std::endl;
+    std::cout << "MHC Kernel V5 Benchmark (BF16) - Split-K" << std::endl;
+    std::cout << "========================================" << std::endl;
"========================================" << std::endl; + std::cout << "Configuration:" << std::endl; + std::cout << " Batch size (B): " << B << std::endl; + std::cout << " Expansion factor (n): " << n << std::endl; + std::cout << " Channels per stream (C): " << C << std::endl; + std::cout << " Input dimension (nC): " << nC << std::endl; + std::cout << " Output dimension (2n+n^2): " << output_dim << std::endl; + std::cout << " Data types: X=" << typeid(XDataType).name() + << ", Phi=" << typeid(PhiDataType).name() << ", Y=" << typeid(YDataType).name() + << ", Compute=" << typeid(ComputeDataType).name() << std::endl; + std::cout << " Warmup iterations: " << warmup << std::endl; + std::cout << " Benchmark iterations: " << repeat << std::endl; + std::cout << "========================================" << std::endl; + + // Allocate host tensors + ck_tile::HostTensor h_x({B, nC}); + ck_tile::HostTensor h_phi({nC, output_dim}); + ck_tile::HostTensor h_output({B, output_dim}); + + // Initialize with random data + ck_tile::FillUniformDistribution{-1.0f, 1.0f}(h_x); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(h_phi); + h_output.SetZero(); + + // Allocate device memory + ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes()); + + // Copy data to device + d_x_mem.ToDevice(h_x.data()); + d_phi_mem.ToDevice(h_phi.data()); + d_output_mem.ToDevice(h_output.data()); + + using Problem = ck_tile::MHCProblemV5; + + // V5 kernel - split-K implementation with adaptive problem + using KernelV5 = ck_tile::MHCKernelV5; + using ReductionKernel = ck_tile::MHCReductionKernel; + + const ck_tile::index_t kBlockSize = KernelV5::BlockSize(); + + // 2D grid: (batch / kMTile) × (nC / kKTile) + auto grid_size = KernelV5::GetGridSize(B, output_dim, nC); + const ck_tile::index_t grid_m = grid_size.at(ck_tile::number<0>{}); + const ck_tile::index_t grid_k = grid_size.at(ck_tile::number<1>{}); + const ck_tile::index_t kGridSize = grid_m * grid_k; + + std::cout << "\nKernel Configuration:" << std::endl; + std::cout << " Grid: " << grid_m << " × " << grid_k << " = " << kGridSize << " blocks" + << std::endl; + std::cout << " Block size: " << kBlockSize << " threads" << std::endl; + std::cout << " Shared memory: " << KernelV5::GetSmemSize() << " bytes" << std::endl; + std::cout << " Split-K factor: " << grid_k << std::endl; + + // Allocate workspace for split-K partial results + const std::size_t workspace_size = grid_k * B * output_dim * sizeof(ComputeDataType); + const std::size_t partial_norms_size = grid_k * B * sizeof(ComputeDataType); + + ck_tile::DeviceMem d_workspace_mem(workspace_size); + ck_tile::DeviceMem d_partial_norms_mem(partial_norms_size); + + // Initialize workspace to zero + (void)hipMemset(d_workspace_mem.GetDeviceBuffer(), 0, workspace_size); + (void)hipMemset(d_partial_norms_mem.GetDeviceBuffer(), 0, partial_norms_size); + + std::cout << " Workspace size: " << workspace_size / (1024.0 * 1024.0) << " MB" << std::endl; + + constexpr ck_tile::index_t kBlockPerCu = 1; + + // Reduction kernel configuration + const ck_tile::index_t reduction_threads = ReductionKernel::BlockSize(); + const ck_tile::index_t reduction_blocks = + (B * output_dim + reduction_threads - 1) / reduction_threads; + + // Combined kernel launch with timing - warmup and repeat handled by launch_kernel + auto launch_combined = [&]() { + // Launch split-K kernel + 
+
+    // Reduction kernel configuration
+    const ck_tile::index_t reduction_threads = ReductionKernel::BlockSize();
+    const ck_tile::index_t reduction_blocks =
+        (B * output_dim + reduction_threads - 1) / reduction_threads;
+
+    // Combined launch of both kernels; warmup and timing are handled manually below
+    auto launch_combined = [&]() {
+        // Launch split-K kernel
+        ck_tile::launch_kernel(
+            ck_tile::stream_config{nullptr, false},
+            ck_tile::make_kernel<kBlockPerCu>(
+                KernelV5{},
+                kGridSize,
+                kBlockSize,
+                KernelV5::GetSmemSize(),
+                static_cast<const XDataType*>(d_x_mem.GetDeviceBuffer()),
+                static_cast<const PhiDataType*>(d_phi_mem.GetDeviceBuffer()),
+                static_cast<ComputeDataType*>(d_workspace_mem.GetDeviceBuffer()),
+                static_cast<ComputeDataType*>(d_partial_norms_mem.GetDeviceBuffer()),
+                B,
+                nC,
+                output_dim,
+                n,
+                r,
+                alpha_pre,
+                alpha_post,
+                alpha_res,
+                bias));
+
+        // Launch reduction kernel
+        ck_tile::launch_kernel(
+            ck_tile::stream_config{nullptr, false},
+            ck_tile::make_kernel<kBlockPerCu>(
+                ReductionKernel{},
+                reduction_blocks,
+                reduction_threads,
+                0,
+                static_cast<const ComputeDataType*>(d_workspace_mem.GetDeviceBuffer()),
+                static_cast<const ComputeDataType*>(d_partial_norms_mem.GetDeviceBuffer()),
+                static_cast<YDataType*>(d_output_mem.GetDeviceBuffer()),
+                B,
+                nC,
+                output_dim,
+                n,
+                grid_k,
+                alpha_pre,
+                alpha_post,
+                alpha_res,
+                bias));
+    };
+
+    // Warmup
+    for(int i = 0; i < warmup; ++i)
+    {
+        launch_combined();
+    }
+
+    // Benchmark with manual timing
+    hipEvent_t start, stop;
+    (void)hipEventCreate(&start);
+    (void)hipEventCreate(&stop);
+
+    (void)hipEventRecord(start);
+    for(int i = 0; i < repeat; ++i)
+    {
+        launch_combined();
+    }
+    (void)hipEventRecord(stop);
+    (void)hipEventSynchronize(stop);
+
+    float total_time = 0;
+    (void)hipEventElapsedTime(&total_time, start, stop);
+    float ave_time = total_time / repeat;
+
+    (void)hipEventDestroy(start);
+    (void)hipEventDestroy(stop);
+
+    // Calculate performance metrics
+    std::size_t num_bytes = sizeof(XDataType) * B * nC +            // Input x
+                            sizeof(PhiDataType) * nC * output_dim + // Weights phi
+                            sizeof(YDataType) * B * output_dim;     // Output
+
+    float gb_per_sec = num_bytes / 1.E6 / ave_time;
+
+    // Calculate FLOPs: ~B * output_dim * 2*nC multiply-adds for the GEMM (epilogue ops excluded)
+    std::size_t num_flops = static_cast<std::size_t>(B) * output_dim * (2 * nC);
+    float tflops          = num_flops / 1.E9 / ave_time;
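+
+    // Example with the defaults (B=1024, nC=16384, output_dim=24):
+    // num_flops = 1024 * 24 * 32768 ~ 0.81 GFLOP, so an average time of 0.1 ms
+    // would report ~8 TFLOPS. Note that num_bytes counts only x, phi and the
+    // final output, so the bandwidth figure excludes split-K workspace traffic.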
"PASS" : "FAIL") << std::endl; + } + + return pass; +} + +// Runtime dispatch wrapper for adaptive tile selection +template +bool run_mhc_benchmark(const ck_tile::ArgParser& arg_parser) +{ + const int B = arg_parser.get_int("B"); + + // Adaptive tile selection based on batch size + if(B >= 4096) + { + std::cout << "[Adaptive] Using M=64 tile for large batch (B=" << B << ")" << std::endl; + return run_mhc_benchmark_impl(arg_parser); + } + else + { + std::cout << "[Adaptive] Using M=16 tile for small/medium batch (B=" << B << ")" + << std::endl; + return run_mhc_benchmark_impl(arg_parser); + } +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + std::cout << "Failed to parse arguments!" << std::endl; + return -1; + } + + // Run with BF16 inputs, float output and compute + // Adaptive tile selection happens inside run_mhc_benchmark + bool pass = run_mhc_benchmark(arg_parser); + + return pass ? 0 : -2; +} diff --git a/include/ck_tile/ops/mhc.hpp b/include/ck_tile/ops/mhc.hpp index 90a4a21350..c2e818085d 100644 --- a/include/ck_tile/ops/mhc.hpp +++ b/include/ck_tile/ops/mhc.hpp @@ -7,10 +7,12 @@ #include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v2.hpp" #include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp" #include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp" +#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v5.hpp" #include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp" #include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp" #include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp" #include "ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp" +#include "ck_tile/ops/mhc/pipeline/mhc_problem_v5.hpp" #include "ck_tile/ops/mhc/pipeline/mhc_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_interleaved_pk_type.hpp" diff --git a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v5.hpp b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v5.hpp new file mode 100644 index 0000000000..5e8fb1d9ea --- /dev/null +++ b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v5.hpp @@ -0,0 +1,409 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp" +#include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +// Manifold Constrained Hyper Connection Kernel V5: +// ===================================================================== +// Split-K implementation with 2D grid (B, C): +// - Grid dimension 0: Batch tiles (B / kMTile) +// - Grid dimension 1: C tiles (nC / kKTile) - split-K dimension +// - Each block computes partial GEMM for its C-tile +// - Results stored to workspace buffer (no atomics!) 
+
+namespace ck_tile {
+
+template <typename Activation_, typename Problem_, typename Policy_ = MHCDefaultPolicy>
+struct MHCKernelV5
+{
+    using Activation = ck_tile::remove_cvref_t<Activation_>;
+    using Problem    = ck_tile::remove_cvref_t<Problem_>;
+    using Policy     = ck_tile::remove_cvref_t<Policy_>;
+
+    using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
+    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
+    using YDataType       = ck_tile::remove_cvref_t<typename Problem::YDataType>;
+    using PhiDataType     = ck_tile::remove_cvref_t<typename Problem::PhiDataType>;
+
+    // Tile sizes from BlockGemmShape
+    static constexpr index_t kMTile = Problem::BlockGemmShape::kM; // Batch tile (16 or 64)
+    static constexpr index_t kNTile = Problem::BlockGemmShape::kN; // Output tile (32)
+    static constexpr index_t kKTile = Problem::BlockGemmShape::kK; // K tile for C dimension (128)
+
+    static constexpr index_t kBlockSize = Problem::kBlockSize;
+
+    CK_TILE_HOST static constexpr auto BlockSize() { return kBlockSize; }
+
+    // Padding to avoid LDS bank conflicts
+    // AMD GPUs have 32 LDS banks with a 4-byte bank width.
+    // For bf16 (2 bytes) we pad the row stride so it is not a multiple of 32 banks.
+    static constexpr index_t kKTilePadded = kKTile + 8; // Add 8 elements of padding
+
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    {
+        // LDS for BlockGemm with padding: A[kMTile, kKTile+8] + B[kNTile, kKTile+8]
+        constexpr index_t a_lds_size = kMTile * kKTilePadded * sizeof(XDataType);
+        constexpr index_t b_lds_size = kNTile * kKTilePadded * sizeof(PhiDataType);
+        return a_lds_size + b_lds_size;
+    }
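+
+    // Bank-conflict arithmetic behind kKTilePadded (derived from the sizes above,
+    // with the default kKTile = 128): an unpadded bf16 row of 128 elements spans
+    // 256 bytes = 64 banks, i.e. 0 mod 32, so a column walk would hit one bank.
+    // The padded stride of 136 elements spans 272 bytes = 68 banks = 4 mod 32,
+    // staggering consecutive rows across the banks.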
+
+    // Grid configuration: 2D grid (B, C) for split-K
+    CK_TILE_HOST static constexpr auto
+    GetGridSize(index_t batch, [[maybe_unused]] index_t output_dim, index_t nC)
+    {
+        const index_t grid_m = (batch + kMTile - 1) / kMTile;
+        const index_t grid_k = (nC + kKTile - 1) / kKTile; // Split-K dimension
+        return make_tuple(grid_m, grid_k);
+    }
+
+    CK_TILE_DEVICE void operator()(const XDataType* p_x,
+                                   const PhiDataType* p_phi,
+                                   ComputeDataType* p_workspace,     // [grid_k, batch, output_dim]
+                                   ComputeDataType* p_partial_norms, // [grid_k, batch]
+                                   index_t batch,
+                                   index_t nC,
+                                   index_t output_dim,
+                                   [[maybe_unused]] index_t n,
+                                   [[maybe_unused]] float r          = 1.0f,
+                                   [[maybe_unused]] float alpha_pre  = 1.0f,
+                                   [[maybe_unused]] float alpha_post = 1.0f,
+                                   [[maybe_unused]] float alpha_res  = 1.0f,
+                                   [[maybe_unused]] float bias       = 0.0f) const
+    {
+        // 2D block indexing
+        const index_t grid_m  = (batch + kMTile - 1) / kMTile;
+        const index_t block_m = get_block_id() % grid_m;
+        const index_t block_k = get_block_id() / grid_m;
+
+        const index_t batch_start = block_m * kMTile;
+        const index_t k_start     = block_k * kKTile;
+        const index_t out_start   = 0;
+
+        if(batch_start >= batch || k_start >= nC)
+            return;
+
+        // Allocate shared memory with padding
+        __shared__ char smem_ptr[GetSmemSize()];
+        XDataType* x_lds     = reinterpret_cast<XDataType*>(smem_ptr);
+        PhiDataType* phi_lds =
+            reinterpret_cast<PhiDataType*>(smem_ptr + kMTile * kKTilePadded * sizeof(XDataType));
+
+        // Create BlockGemm instance and result tile
+        using BlockGemm  = BlockGemmASmemBSmemCRegV1<Problem>;
+        auto result_tile = BlockGemm::MakeCBlockTile();
+        set_tile(result_tile, 0.0f);
+
+        // Determine actual K size for this block
+        const index_t k_size = ck_tile::min(kKTile, nC - k_start);
+
+        // Create tensor views for X and Phi
+        auto x_tensor_full = make_naive_tensor_view<address_space_enum::global>(
+            p_x, make_tuple(batch, nC), make_tuple(nC, 1), number<1>{}, number<1>{});
+
+        auto x_tensor_padded = pad_tensor_view(x_tensor_full,
+                                               make_tuple(number<kMTile>{}, number<kKTile>{}),
+                                               sequence<true, true>{});
+
+        constexpr auto x_load_tile_dist = Problem::MakeXLoadTileDistribution();
+        auto x_dram_window = make_tile_window(x_tensor_padded,
+                                              make_tuple(number<kMTile>{}, number<kKTile>{}),
+                                              {batch_start, k_start},
+                                              x_load_tile_dist);
+
+        auto x_lds_tensor = make_naive_tensor_view<address_space_enum::lds>(
+            x_lds,
+            make_tuple(number<kMTile>{}, number<kKTile>{}),
+            make_tuple(number<kKTilePadded>{}, number<1>{}),
+            number<1>{},
+            number<1>{});
+
+        auto x_lds_window =
+            make_tile_window(x_lds_tensor, make_tuple(number<kMTile>{}, number<kKTile>{}), {0, 0});
+
+        // Create Phi tensor view and window
+        auto phi_tensor_full = make_naive_tensor_view<address_space_enum::global>(
+            p_phi, make_tuple(output_dim, nC), make_tuple(1, output_dim), number<1>{}, number<1>{});
+
+        auto phi_tensor_padded = pad_tensor_view(phi_tensor_full,
+                                                 make_tuple(number<kNTile>{}, number<kKTile>{}),
+                                                 sequence<true, true>{});
+
+        constexpr auto phi_load_tile_dist = Problem::MakePhiLoadTileDistribution();
+        auto phi_dram_window = make_tile_window(phi_tensor_padded,
+                                                make_tuple(number<kNTile>{}, number<kKTile>{}),
+                                                {out_start, k_start},
+                                                phi_load_tile_dist);
+
+        auto phi_lds_tensor = make_naive_tensor_view<address_space_enum::lds>(
+            phi_lds,
+            make_tuple(number<kNTile>{}, number<kKTile>{}),
+            make_tuple(number<kKTilePadded>{}, number<1>{}),
+            number<1>{},
+            number<1>{});
+
+        auto phi_lds_window = make_tile_window(
+            phi_lds_tensor, make_tuple(number<kNTile>{}, number<kKTile>{}), {0, 0});
+
+        // Compute partial norms for this K-tile
+        const index_t thread_id           = get_thread_id();
+        constexpr index_t threads_per_row = kBlockSize / kMTile;
+        const index_t row_id              = thread_id / threads_per_row;
+        const index_t thread_in_row       = thread_id % threads_per_row;
+
+        __shared__ ComputeDataType norm_reduction[kMTile][threads_per_row];
+
+        if(row_id < kMTile)
+        {
+            const index_t global_m      = batch_start + row_id;
+            ComputeDataType partial_sum = 0.0f;
+
+            if(global_m < batch)
+            {
+                const XDataType* row_ptr = p_x + global_m * nC + k_start;
+
+                constexpr index_t kVecSize = 4;
+                for(index_t k = thread_in_row * kVecSize; k < k_size;
+                    k += threads_per_row * kVecSize)
+                {
+                    if(k + kVecSize <= k_size)
+                    {
+                        using VecType = ext_vector_t<XDataType, kVecSize>;
+                        VecType vec   = *c_style_pointer_cast<const VecType*>(row_ptr + k);
+
+#pragma unroll
+                        for(index_t i = 0; i < kVecSize; ++i)
+                        {
+                            ComputeDataType val = type_convert<ComputeDataType>(vec[i]);
+                            partial_sum += val * val;
+                        }
+                    }
+                    else
+                    {
+                        for(index_t i = 0; i < kVecSize && k + i < k_size; ++i)
+                        {
+                            ComputeDataType val = type_convert<ComputeDataType>(row_ptr[k + i]);
+                            partial_sum += val * val;
+                        }
+                    }
+                }
+            }
+
+            norm_reduction[row_id][thread_in_row] = partial_sum;
+        }
+
+        block_sync_lds();
+
+        // Reduce and store partial norms to global memory
+        if(thread_in_row == 0 && row_id < kMTile)
+        {
+            const index_t global_m = batch_start + row_id;
+
+            if(global_m < batch)
+            {
+                ComputeDataType sum_squares = 0.0f;
+#pragma unroll
+                for(index_t t = 0; t < threads_per_row; ++t)
+                {
+                    sum_squares += norm_reduction[row_id][t];
+                }
+
+                // Store to global memory: p_partial_norms[block_k, global_m]
+                p_partial_norms[block_k * batch + global_m] = sum_squares;
+            }
+        }
+
+        // Load X tile for this K-slice
+        auto x_tile = make_static_distributed_tensor<XDataType>(x_load_tile_dist);
+        load_tile(x_tile, x_dram_window);
+        store_tile(x_lds_window, x_tile);
+
+        // Load Phi tile for this K-slice
+        auto phi_tile = make_static_distributed_tensor<PhiDataType>(phi_load_tile_dist);
+        load_tile(phi_tile, phi_dram_window);
+        store_tile(phi_lds_window, phi_tile);
+
+        block_sync_lds();
+
+        // Perform GEMM for this K-slice: result_tile = x_lds * phi_lds^T
+        // Note: this is a partial result for just this K-tile
+        BlockGemm{}(result_tile, x_lds_window, phi_lds_window);
+
+        block_sync_lds();
+
+        // Store partial results to workspace buffer: p_workspace[block_k, batch, output_dim]
+        // Layout: [grid_k][batch][output_dim]
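+        // Indexing example (restating the layout with the default shapes): for
+        // grid_k = 128, batch = 1024, output_dim = 24, the partial from block_k = 3
+        // for element (m = 10, o = 5) lands at 3 * (1024 * 24) + 10 * 24 + 5 = 73973.
+        // Each block owns a disjoint [batch, output_dim] slice, hence no atomics.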
+        constexpr auto result_spans = decltype(result_tile)::get_distributed_spans();
+        sweep_tile_span(result_spans[number<0>{}], [&](auto idx0) {
+            sweep_tile_span(result_spans[number<1>{}], [&](auto idx1) {
+                const auto tile_idx = get_x_indices_from_distributed_indices(
+                    result_tile.get_tile_distribution(), make_tuple(idx0, idx1));
+
+                const index_t local_m  = tile_idx.at(number<0>{});
+                const index_t local_n  = tile_idx.at(number<1>{});
+                const index_t global_m = batch_start + local_m;
+                const index_t global_n = out_start + local_n;
+
+                if(global_m < batch && global_n < output_dim)
+                {
+                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                    ComputeDataType value  = result_tile[i_j_idx];
+
+                    // Store to workspace: [block_k][global_m][global_n]
+                    const index_t workspace_idx =
+                        block_k * (batch * output_dim) + global_m * output_dim + global_n;
+                    p_workspace[workspace_idx] = value;
+                }
+            });
+        });
+    }
+};
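+
+// For reference, the epilogue applied per output column j (a restatement of the
+// branches in the kernel below, with v the reduced GEMM value and nu the row norm):
+//
+//   j in [0, n)         : y = (alpha_pre  / nu) * act(v) + bias        (pre-activation)
+//   j in [n, 2n)        : y = (alpha_post / nu) * 2 * act(v) + bias    (post-activation)
+//   j in [2n, 2n + n^2) : y = (alpha_res  / nu) * v + bias             (residual)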
+
+// Reduction kernel: combines the split-K partials and applies the MHC epilogue
+template <typename Activation_, typename Problem_>
+struct MHCReductionKernel
+{
+    using Activation      = ck_tile::remove_cvref_t<Activation_>;
+    using Problem         = ck_tile::remove_cvref_t<Problem_>;
+    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
+    using YDataType       = ck_tile::remove_cvref_t<typename Problem::YDataType>;
+
+    static constexpr index_t kBlockSize = 256;
+    static constexpr index_t kVecSize   = 4; // Vectorized loads
+
+    CK_TILE_HOST static constexpr auto BlockSize() { return kBlockSize; }
+
+    CK_TILE_DEVICE void operator()(const ComputeDataType* p_workspace,
+                                   const ComputeDataType* p_partial_norms,
+                                   YDataType* p_output,
+                                   index_t batch,
+                                   index_t nC,
+                                   index_t output_dim,
+                                   index_t n,
+                                   index_t grid_k,
+                                   float alpha_pre,
+                                   float alpha_post,
+                                   float alpha_res,
+                                   float bias) const
+    {
+        const index_t tid        = get_thread_id();
+        const index_t block_id   = get_block_id();
+        const index_t block_size = get_block_size();
+
+        // Each block processes block_size consecutive output elements, so threads
+        // in a warp read consecutive addresses (coalesced).
+        const index_t elements_per_block = block_size;
+        const index_t global_start       = block_id * elements_per_block;
+        const index_t total_elements     = batch * output_dim;
+
+        const index_t global_idx = global_start + tid;
+
+        if(global_idx >= total_elements)
+            return;
+
+        const index_t global_m = global_idx / output_dim;
+        const index_t global_n = global_idx % output_dim;
+
+        // Reduce partial norms with vectorized loads where possible
+        ComputeDataType sum_squares = 0.0f;
+        const index_t norm_base     = global_m;
+
+        // Vectorized reduction for norms
+        index_t k = 0;
+        for(; k + kVecSize <= grid_k; k += kVecSize)
+        {
+            using VecType = ext_vector_t<ComputeDataType, kVecSize>;
+            VecType vec_norms;
+
+#pragma unroll
+            for(index_t i = 0; i < kVecSize; ++i)
+            {
+                vec_norms[i] = p_partial_norms[(k + i) * batch + norm_base];
+            }
+
+#pragma unroll
+            for(index_t i = 0; i < kVecSize; ++i)
+            {
+                sum_squares += vec_norms[i];
+            }
+        }
+
+        // Handle remaining elements
+        for(; k < grid_k; ++k)
+        {
+            sum_squares += p_partial_norms[k * batch + norm_base];
+        }
+
+        const ComputeDataType sqrt_nC = ck_tile::sqrt(static_cast<ComputeDataType>(nC));
+        ComputeDataType norm          = ck_tile::sqrt(sum_squares) / sqrt_nC;
+        norm                          = (norm > 1e-12f) ? norm : 1.0f;
+
+        // Reduce partial GEMM results; consecutive threads read consecutive
+        // global_idx positions within each k-slice, keeping accesses coalesced.
+        ComputeDataType value          = 0.0f;
+        const index_t workspace_stride = batch * output_dim;
+
+        // Vectorized reduction for workspace
+        k = 0;
+        for(; k + kVecSize <= grid_k; k += kVecSize)
+        {
+            using VecType = ext_vector_t<ComputeDataType, kVecSize>;
+            VecType vec_values;
+
+#pragma unroll
+            for(index_t i = 0; i < kVecSize; ++i)
+            {
+                const index_t workspace_idx = (k + i) * workspace_stride + global_idx;
+                vec_values[i]               = p_workspace[workspace_idx];
+            }
+
+#pragma unroll
+            for(index_t i = 0; i < kVecSize; ++i)
+            {
+                value += vec_values[i];
+            }
+        }
+
+        // Handle remaining elements
+        for(; k < grid_k; ++k)
+        {
+            const index_t workspace_idx = k * workspace_stride + global_idx;
+            value += p_workspace[workspace_idx];
+        }
+
+        // Apply normalization and activation based on output section
+        ComputeDataType final_value;
+        if(global_n < n)
+        {
+            // Pre-activation section
+            ComputeDataType activated_value;
+            Activation{}(activated_value, value);
+            final_value = (alpha_pre / norm) * activated_value + bias;
+        }
+        else if(global_n < 2 * n)
+        {
+            // Post-activation section
+            ComputeDataType activated_value;
+            Activation{}(activated_value, value);
+            final_value = (alpha_post / norm) * 2.0f * activated_value + bias;
+        }
+        else
+        {
+            // Residual section
+            final_value = (alpha_res / norm) * value + bias;
+        }
+
+        p_output[global_idx] = type_convert<YDataType>(final_value);
+    }
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/mhc/pipeline/mhc_problem_v5.hpp b/include/ck_tile/ops/mhc/pipeline/mhc_problem_v5.hpp
new file mode 100644
index 0000000000..e12b897f57
--- /dev/null
+++ b/include/ck_tile/ops/mhc/pipeline/mhc_problem_v5.hpp
@@ -0,0 +1,126 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp"
+
+namespace ck_tile {
+
+// MHC Problem V5: optimized for large C values with split-K.
+// Adaptive M tile size based on batch size for optimal performance.
+template <typename XDataType_,
+          typename ComputeDataType_,
+          typename YDataType_,
+          index_t MTile_ = 16> // Default M=16 for small/medium batches
+struct MHCProblemV5
+{
+    using XDataType       = remove_cvref_t<XDataType_>;
+    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
+    using YDataType       = remove_cvref_t<YDataType_>;
+
+    using PhiDataType = XDataType;
+
+    // BlockGemm compatibility
+    using ADataType = XDataType;
+    using BDataType = PhiDataType;
+    using CDataType = ComputeDataType;
+
+    static constexpr index_t kMTile = MTile_; // Adaptive M tile size
+
+    // Adaptive tile configuration
+    // M=16 (default): optimal for small/medium batches (B < 4096)
+    // M=64: optimal for large batches (B >= 4096)
+    // N=32, K=128: fixed for all configurations
+    using BlockGemmShape = TileGemmShape<sequence<kMTile, 32, 128>,  // BlockTile: adaptive M
+                                         sequence<1, 1, 1>,          // BlockWarps: 1 warp
+                                         sequence<kMTile, 32, 128>>; // WarpTile: matches BlockTile
+
+    static constexpr index_t VectorSizeA = 4;
+    static constexpr index_t VectorSizeB = 4;
+
+    // 1 warp × 64 threads/warp = 64 threads (same as V4)
+    using BlockShape = Generic2dBlockShape<sequence<1, 1>, sequence<1, 64>, sequence<1, 1>>;
+
+    using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
+    using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
+
+    using AsDataTypeTuple = tuple<ADataType>;
+    using BsDataTypeTuple = tuple<BDataType>;
+    using AsLayoutTuple   = tuple<ALayout>;
+    using BsLayoutTuple   = tuple<BLayout>;
+
+    using AElementWise = identity;
+    using BElementWise = identity;
+
+    static constexpr bool TransposeC = false;
+    static constexpr bool kPadM      = true;
+    static constexpr bool kPadN      = true;
+    static constexpr bool kPadK      = true;
+    static constexpr bool Preshuffle = false;
+
+    static constexpr auto Scheduler        = GemmPipelineScheduler::Intrawave;
+    static constexpr index_t NumWaveGroups = 1;
+
+    static constexpr index_t VectorLoadSize = 16;
+    static constexpr index_t kBlockSize     = BlockShape::BlockSize;
+
+    static constexpr bool DoubleSmemBuffer      = true;
+    static constexpr bool UseStructuredSparsity = false;
+    static constexpr bool FixedVectorSize       = false;
+
+    struct Traits
+    {
+        static constexpr bool UsePersistentKernel = false;
+    };
+
+    CK_TILE_HOST static const std::string GetName() { return "MHCProblemV5"; }
+
+    // Adaptive tile distribution for loading X (input matrix)
+    // X is [Batch, nC] row-major; we load kM×kK tiles
+    // For M=16: H0 (M): [grid=1, warp=1, thread=16, vector=1] = 16
+    // For M=64: H0 (M): [grid=4, warp=1, thread=16, vector=1] = 64
+    // H1 (K): [grid=2, warp=1, thread=4, vector=16] = 128 (same for all)
+    CK_TILE_HOST_DEVICE static constexpr auto MakeXLoadTileDistribution()
+    {
+        using namespace ck_tile;
+
+        constexpr index_t m_grid = MTile_ / 16; // M=16 → grid=1, M=64 → grid=4
+
+        using XTileDistEncoding = tile_distribution_encoding<
+            sequence<>,                            // R: no replication
+            tuple<sequence<m_grid, 1, 16, 1>,      // H0 (M): adaptive grid based on MTile_
+                  sequence<2, 1, 4, 16>>,          // H1 (K): grid=2, warp=1, thread=4, vector=16
+            tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major: warp arrangement
+            tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor: thread arrangement
+            sequence<1, 1, 2, 2>,                  // Y→RH major: data layout
+            sequence<0, 3, 0, 3>>;                 // Y→RH minor: vectorization
+
+        return make_static_tile_distribution(XTileDistEncoding{});
+    }
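+
+    // Decoding the X encoding above for the default M=16 case (a restatement):
+    // 16 threads along M × 4 threads along K = 64 lanes of the single wave; each
+    // lane loads a 16-element K vector, and the K grid factor of 2 repeats the
+    // access to cover 2 * 4 * 16 = 128 K elements, i.e. one 16×128 tile per block.
+    // For M=64 the M grid factor of 4 repeats the pattern four times along M.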
+
+    // Tile distribution for loading Phi (weight matrix)
+    // Phi is stored [nC, output_dim] row-major and viewed as its transpose
+    // [output_dim, nC]; we load kN×kK tiles (32×128)
+    // H0 (N): [grid=1, warp=1, thread=16, vector=2] = 32
+    // H1 (K): [grid=2, warp=1, thread=4, vector=16] = 128
+    CK_TILE_HOST_DEVICE static constexpr auto MakePhiLoadTileDistribution()
+    {
+        using namespace ck_tile;
+
+        using PhiTileDistEncoding = tile_distribution_encoding<
+            sequence<>,                            // R: no replication
+            tuple<sequence<1, 1, 16, 2>,           // H0 (N): grid=1, warp=1, thread=16, vector=2
+                  sequence<2, 1, 4, 16>>,          // H1 (K): grid=2, warp=1, thread=4, vector=16
+            tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major: warp arrangement
+            tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor: thread arrangement
+            sequence<1, 1, 2, 2>,                  // Y→RH major: data layout
+            sequence<0, 3, 0, 3>>;                 // Y→RH minor: vectorization
+
+        return make_static_tile_distribution(PhiTileDistEncoding{});
+    }
+};
+
+} // namespace ck_tile