diff --git a/example/ck_tile/42_mhc/CMakeLists.txt b/example/ck_tile/42_mhc/CMakeLists.txt
index 99962ed953..446e92481c 100644
--- a/example/ck_tile/42_mhc/CMakeLists.txt
+++ b/example/ck_tile/42_mhc/CMakeLists.txt
@@ -15,3 +15,6 @@
 add_executable(${TARGET_NAME} mhc_v3_two_block_test.cpp)
 set(TARGET_NAME example_mhc_v3_bf16_benchmark)
 add_executable(${TARGET_NAME} mhc_v3_bf16_benchmark.cpp)
+
+set(TARGET_NAME example_mhc_v4_bf16_benchmark)
+add_executable(${TARGET_NAME} mhc_v4_bf16_benchmark.cpp)
diff --git a/example/ck_tile/42_mhc/mhc_v4_bf16_benchmark.cpp b/example/ck_tile/42_mhc/mhc_v4_bf16_benchmark.cpp
new file mode 100644
index 0000000000..53eec3a491
--- /dev/null
+++ b/example/ck_tile/42_mhc/mhc_v4_bf16_benchmark.cpp
@@ -0,0 +1,216 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include <cstddef>
+#include <iostream>
+#include <tuple>
+#include <type_traits>
+#include <typeinfo>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/ops/mhc.hpp"
+#include "ck_tile/host/kernel_launch.hpp"
+#include "ck_tile/host/reference/reference_mhc.hpp"
+#include "ck_tile/host/check_err.hpp"
+
+// Parse command-line arguments for the MHC benchmark
+auto create_args(int argc, char* argv[])
+{
+    ck_tile::ArgParser arg_parser;
+    arg_parser.insert("B", "1024", "Batch size")
+        .insert("n", "4", "Expansion factor (number of streams)")
+        .insert("C", "4096", "Channels per stream")
+        .insert("v", "1", "CPU validation (0=no, 1=yes)")
+        .insert("warmup", "5", "Number of warmup iterations")
+        .insert("repeat", "20", "Number of benchmark iterations")
+        .insert("r", "2.0", "Norm scaling factor")
+        .insert("alpha_pre", "1.5", "Alpha for pre-activation")
+        .insert("alpha_post", "2.5", "Alpha for post-activation")
+        .insert("alpha_res", "3.5", "Alpha for residual")
+        .insert("bias", "1.5", "Bias value");
+
+    bool result = arg_parser.parse(argc, argv);
+    return std::make_tuple(result, arg_parser);
+}
+
+template <typename XDataType,
+          typename PhiDataType,
+          typename YDataType,
+          typename ComputeDataType,
+          typename ActivationFunc>
+bool run_mhc_benchmark(const ck_tile::ArgParser& arg_parser)
+{
+    const int B = arg_parser.get_int("B");
+    const int n = arg_parser.get_int("n");
+    const int C = arg_parser.get_int("C");
+
+    const int nC         = n * C;
+    const int output_dim = 2 * n + n * n;
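+    // With the defaults (n = 4): output_dim = 2*4 + 4*4 = 24 columns per row,
+    // laid out as [n pre | n post | n*n residual-mix] coefficients (see the
+    // per-section handling in mhc_kernel_tile_v4.hpp below).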
std::cout << "========================================" << std::endl; + + // Allocate host tensors + ck_tile::HostTensor h_x({B, nC}); + ck_tile::HostTensor h_phi({nC, output_dim}); + ck_tile::HostTensor h_output({B, output_dim}); + + // Initialize with random data + ck_tile::FillUniformDistribution{-1.0f, 1.0f}(h_x); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(h_phi); + h_output.SetZero(); + + // Allocate device memory + ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes()); + + // Copy data to device + d_x_mem.ToDevice(h_x.data()); + d_phi_mem.ToDevice(h_phi.data()); + d_output_mem.ToDevice(h_output.data()); + + // Define block shape - 64 threads (1 warp) to match BlockGemmShape configuration + // This matches a 16x16 block tile with 1 warp (1x1 warp layout) + using BlockShape = ck_tile::Generic2dBlockShape, + ck_tile::sequence<1, 64>, + ck_tile::sequence<1, 1>>; + + using Problem = ck_tile::MHCProblem; + + // V4 kernel - optimized with single-pass data loading + using KernelV4 = ck_tile::MHCKernelV4; + + const ck_tile::index_t kBlockSize = KernelV4::BlockSize(); + + // 2D grid: (batch / kMTile) × (output_dim / kNTile) + auto grid_size = KernelV4::GetGridSize(B, output_dim); + const ck_tile::index_t kGridSize = + grid_size.at(ck_tile::number<0>{}) * grid_size.at(ck_tile::number<1>{}); + + std::cout << "\nKernel Configuration:" << std::endl; + std::cout << " Grid: " << grid_size.at(ck_tile::number<0>{}) << " × " + << grid_size.at(ck_tile::number<1>{}) << " = " << kGridSize << " blocks" << std::endl; + std::cout << " Block size: " << kBlockSize << " threads" << std::endl; + std::cout << " Shared memory: " << KernelV4::GetSmemSize() << " bytes" << std::endl; + + constexpr ck_tile::index_t kBlockPerCu = 1; + + // Launch kernel with timing + float ave_time = ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel(KernelV4{}, + kGridSize, + kBlockSize, + KernelV4::GetSmemSize(), + static_cast(d_x_mem.GetDeviceBuffer()), + static_cast(d_phi_mem.GetDeviceBuffer()), + static_cast(d_output_mem.GetDeviceBuffer()), + B, + nC, + output_dim, + n, + r, + alpha_pre, + alpha_post, + alpha_res, + bias)); + + // Calculate performance metrics + std::size_t num_bytes = sizeof(XDataType) * B * nC + // Input x + sizeof(PhiDataType) * nC * output_dim + // Weights phi + sizeof(YDataType) * B * output_dim; // Output + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + // Calculate FLOPs: B * output_dim * (2*nC - 1) for GEMM + additional ops + std::size_t num_flops = static_cast(B) * output_dim * (2 * nC); + float tflops = num_flops / 1.E9 / ave_time; + + std::cout << "\n========================================" << std::endl; + std::cout << "Performance Results:" << std::endl; + std::cout << " Average time: " << ave_time << " ms" << std::endl; + std::cout << " Bandwidth: " << gb_per_sec << " GB/s" << std::endl; + std::cout << " Throughput: " << tflops << " TFLOPS" << std::endl; + std::cout << "========================================" << std::endl; + + bool pass = true; + + if(do_validation) + { + std::cout << "\nRunning validation..." 
+
+    std::cout << "\n========================================" << std::endl;
+    std::cout << "Performance Results:" << std::endl;
+    std::cout << "  Average time: " << ave_time << " ms" << std::endl;
+    std::cout << "  Bandwidth:    " << gb_per_sec << " GB/s" << std::endl;
+    std::cout << "  Throughput:   " << tflops << " TFLOPS" << std::endl;
+    std::cout << "========================================" << std::endl;
+
+    bool pass = true;
+
+    if(do_validation)
+    {
+        std::cout << "\nRunning validation..." << std::endl;
+
+        d_output_mem.FromDevice(h_output.data());
+
+        // Compute the reference result on the host
+        ck_tile::HostTensor<YDataType> h_output_ref({B, output_dim});
+        h_output_ref.SetZero();
+
+        ck_tile::reference_mhc(h_x,
+                               h_phi,
+                               h_output_ref,
+                               n,
+                               C,
+                               r,
+                               alpha_pre,
+                               alpha_post,
+                               alpha_res,
+                               bias,
+                               ActivationFunc{});
+
+        // Validate with a tolerance appropriate for bf16
+        float rtol = std::is_same_v<XDataType, ck_tile::bf16_t> ? 1e-2f : 1e-3f;
+        float atol = std::is_same_v<XDataType, ck_tile::bf16_t> ? 1e-2f : 1e-3f;
+
+        pass = ck_tile::check_err(
+            h_output, h_output_ref, "Error: MHC V4 kernel output mismatch!", rtol, atol);
+
+        std::cout << "Validation: " << (pass ? "PASS" : "FAIL") << std::endl;
+    }
+
+    return pass;
+}
+
+int main(int argc, char* argv[])
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+    {
+        std::cout << "Failed to parse arguments!" << std::endl;
+        return -1;
+    }
+
+    // Run with BF16 inputs, float output and compute
+    bool pass = run_mhc_benchmark<ck_tile::bf16_t,
+                                  ck_tile::bf16_t,
+                                  float,
+                                  float,
+                                  ck_tile::element_wise::Sigmoid>(arg_parser);
+
+    return pass ? 0 : -2;
+}
diff --git a/include/ck_tile/ops/mhc.hpp b/include/ck_tile/ops/mhc.hpp
index c336a5de97..690beaafae 100644
--- a/include/ck_tile/ops/mhc.hpp
+++ b/include/ck_tile/ops/mhc.hpp
@@ -6,6 +6,7 @@
 #include "ck_tile/ops/mhc/kernel/mhc_kernel_tile.hpp"
 #include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v2.hpp"
 #include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp"
+#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp"
 #include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp"
 #include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp"
 #include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
diff --git a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp
index 05886c5e09..e37cce73c0 100644
--- a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp
+++ b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp
@@ -181,11 +181,11 @@ struct MHCKernelV3
         auto phi_identity_func = [](auto& e, const PhiDataType& phi_val) { e = phi_val; };
 
         auto result_tile = gemm_pipeline(make_tuple(x_dram_window),
-                                        phi_identity_func,
-                                        make_tuple(phi_dram_window),
-                                        phi_identity_func,
-                                        num_k_loops,
-                                        smem);
+                                         phi_identity_func,
+                                         make_tuple(phi_dram_window),
+                                         phi_identity_func,
+                                         num_k_loops,
+                                         smem);
 
         // Apply normalization and activation in post-processing
         // Now we divide by norm AFTER the GEMM, which means:
@@ -222,7 +222,8 @@
             {
                 ComputeDataType activated_value;
                 Activation{}(activated_value, value);
-                result_tile(i_j_idx) = alpha_post * inv_norm * 2.0f * activated_value + bias;
+                result_tile(i_j_idx) =
+                    alpha_post * inv_norm * 2.0f * activated_value + bias;
             }
             else
             {
diff --git a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp
new file mode 100644
index 0000000000..a1ea5087a2
--- /dev/null
+++ b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp
@@ -0,0 +1,255 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/common.hpp"
+#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
+#include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
+#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
+
+// Manifold Constrained Hyper Connection Kernel V4:
+// =====================================================================
+// Optimizations implemented:
+// - Remove the GEMM pipeline to avoid redundant global memory reads
+// - Manage LDS manually (like v2) and run a hand-rolled block GEMM
+// - Compute the normalization incrementally during the GEMM loop
+// - Single pass through the input data: load once, compute the norm and GEMM together
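+//
+// Normalization used below (an RMS norm over the full input row):
+//     norm(x_b)  = sqrt( (1/nC) * sum_k x[b,k]^2 )
+//     inv_norm_b = 1 / norm(x_b)
+// The per-row sum of squares is accumulated tile-by-tile while each X tile is
+// staged into LDS for the GEMM, so x is read from global memory exactly once.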
+
+namespace ck_tile {
+
+template <typename Problem_, typename Activation_, typename Policy_ = MHCDefaultPolicy>
+struct MHCKernelV4
+{
+    using Activation = ck_tile::remove_cvref_t<Activation_>;
+    using Problem    = ck_tile::remove_cvref_t<Problem_>;
+    using Policy     = ck_tile::remove_cvref_t<Policy_>;
+
+    using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
+    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
+    using YDataType       = ck_tile::remove_cvref_t<typename Problem::YDataType>;
+    using PhiDataType     = ck_tile::remove_cvref_t<typename Problem::PhiDataType>;
+
+    // Tile sizes are derived automatically from BlockGemmShape (single source of truth)
+    static constexpr index_t kMTile = Problem::BlockGemmShape::kM; // Batch tile
+    static constexpr index_t kNTile = Problem::BlockGemmShape::kN; // Output tile
+    static constexpr index_t kKTile = Problem::BlockGemmShape::kK; // K tile for the C dimension
+
+    static constexpr index_t kBlockSize = Problem::kBlockSize;
+
+    CK_TILE_HOST static constexpr auto BlockSize() { return kBlockSize; }
+
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    {
+        // LDS for the GEMM staging tiles: A[kMTile, kKTile] + B[kKTile, kNTile]
+        constexpr index_t a_lds_size = kMTile * kKTile * sizeof(XDataType);
+        constexpr index_t b_lds_size = kKTile * kNTile * sizeof(PhiDataType);
+        return a_lds_size + b_lds_size;
+    }
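+    // For example, with kMTile = kNTile = kKTile = 16 and 2-byte element types
+    // (an illustrative configuration, not a requirement), this is
+    // 16*16*2 + 16*16*2 = 1024 bytes. The norm and result staging buffers in
+    // operator() below are separate static __shared__ arrays not counted here.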
+
+    // Grid configuration: a 2D grid over (batch, output_dim)
+    CK_TILE_HOST static constexpr auto GetGridSize(index_t batch, index_t output_dim)
+    {
+        const index_t grid_m = (batch + kMTile - 1) / kMTile;
+        const index_t grid_n = (output_dim + kNTile - 1) / kNTile;
+        return make_tuple(grid_m, grid_n);
+    }
+
+    CK_TILE_DEVICE void operator()(const XDataType* p_x,
+                                   const PhiDataType* p_phi,
+                                   YDataType* p_output,
+                                   index_t batch,
+                                   index_t nC,
+                                   index_t output_dim,
+                                   index_t n,
+                                   [[maybe_unused]] float r = 1.0f,
+                                   float alpha_pre          = 1.0f,
+                                   float alpha_post         = 1.0f,
+                                   float alpha_res          = 1.0f,
+                                   float bias               = 0.0f) const
+    {
+        // 2D block indexing: decompose the linear block id into (m, n) tile coordinates
+        const index_t grid_n_size = (output_dim + kNTile - 1) / kNTile;
+        const index_t block_id    = get_block_id();
+        const index_t block_m     = block_id / grid_n_size;
+        const index_t block_n     = block_id % grid_n_size;
+
+        const index_t batch_start = block_m * kMTile;
+        const index_t out_start   = block_n * kNTile;
+
+        if(batch_start >= batch || out_start >= output_dim)
+            return;
+
+        const index_t tid = get_thread_id();
+
+        // Shared memory for the A and B staging tiles
+        __shared__ char smem_ptr[GetSmemSize()];
+        XDataType* x_lds = reinterpret_cast<XDataType*>(smem_ptr);
+        PhiDataType* phi_lds =
+            reinterpret_cast<PhiDataType*>(smem_ptr + kMTile * kKTile * sizeof(XDataType));
+
+        // Shared memory for norm accumulation (one entry per batch element in the tile)
+        __shared__ ComputeDataType sum_squares_shared[kMTile];
+
+        // Shared memory for GEMM result accumulation
+        __shared__ ComputeDataType result_shared[kMTile * kNTile];
+
+        // Initialize the shared norm accumulators and the result tile
+        for(index_t i = tid; i < kMTile; i += get_block_size())
+        {
+            sum_squares_shared[i] = 0.0f;
+        }
+        for(index_t i = tid; i < kMTile * kNTile; i += get_block_size())
+        {
+            result_shared[i] = 0.0f;
+        }
+        block_sync_lds();
+
+        // Number of K-tile iterations
+        const index_t num_k_tiles = (nC + kKTile - 1) / kKTile;
+
+        // Main loop: load tiles, accumulate the norms incrementally, and accumulate the GEMM
+        for(index_t k_tile_idx = 0; k_tile_idx < num_k_tiles; ++k_tile_idx)
+        {
+            const index_t k_start = k_tile_idx * kKTile;
+            const index_t k_end   = min(k_start + kKTile, nC);
+            const index_t k_len   = k_end - k_start;
+
+            // Load the X tile from global memory to LDS and accumulate the norm
+            for(index_t i = tid; i < kMTile * kKTile; i += get_block_size())
+            {
+                const index_t local_m  = i / kKTile;
+                const index_t local_k  = i % kKTile;
+                const index_t global_m = batch_start + local_m;
+                const index_t global_k = k_start + local_k;
+
+                XDataType x_val = 0;
+                if(global_m < batch && local_k < k_len)
+                {
+                    x_val = p_x[global_m * nC + global_k];
+
+                    // Accumulate the norm for this batch element using atomics
+                    ComputeDataType x_compute = type_convert<ComputeDataType>(x_val);
+                    ComputeDataType sq        = x_compute * x_compute;
+                    atomicAdd(&sum_squares_shared[local_m], sq);
+                }
+                x_lds[i] = x_val;
+            }
+
+            // Load the Phi tile from global memory to LDS.
+            // phi is stored in global memory as [nC, output_dim] row-major; stage it
+            // K-contiguous (phi_lds[n][k]) so the inner GEMM loop below reads
+            // consecutive K elements
+            for(index_t i = tid; i < kKTile * kNTile; i += get_block_size())
+            {
+                const index_t local_k  = i / kNTile;
+                const index_t local_n  = i % kNTile;
+                const index_t global_k = k_start + local_k;
+                const index_t global_n = out_start + local_n;
+
+                PhiDataType phi_val = 0;
+                if(local_k < k_len && global_n < output_dim)
+                {
+                    phi_val = p_phi[global_k * output_dim + global_n];
+                }
+                // Store K-contiguous: phi_lds[n * kKTile + k]
+                phi_lds[local_n * kKTile + local_k] = phi_val;
+            }
+
+            block_sync_lds();
+
+            // Hand-rolled GEMM: result_shared[m, n] += sum_k x_lds[m, k] * phi_lds[n, k].
+            // Distribute the work so each thread computes a fixed subset of output elements;
+            // with 64 threads and a 16x16 output tile, each thread handles 4 elements
+            const index_t total_elements = kMTile * kNTile;
+            const index_t elements_per_thread =
+                (total_elements + get_block_size() - 1) / get_block_size();
+
+            for(index_t elem_idx = 0; elem_idx < elements_per_thread; ++elem_idx)
+            {
+                const index_t global_elem = tid * elements_per_thread + elem_idx;
+                if(global_elem < total_elements)
+                {
+                    const index_t m_idx = global_elem / kNTile;
+                    const index_t n_idx = global_elem % kNTile;
+
+                    ComputeDataType acc = 0.0f;
+                    for(index_t k_idx = 0; k_idx < kKTile; ++k_idx)
+                    {
+                        ComputeDataType x_val =
+                            type_convert<ComputeDataType>(x_lds[m_idx * kKTile + k_idx]);
+                        ComputeDataType phi_val =
+                            type_convert<ComputeDataType>(phi_lds[n_idx * kKTile + k_idx]);
+                        acc += x_val * phi_val;
+                    }
+                    // Accumulate into shared memory using atomics
+                    atomicAdd(&result_shared[m_idx * kNTile + n_idx], acc);
+                }
+            }
+
+            block_sync_lds();
+        }
+
+        // Ensure all norm accumulations are complete
+        block_sync_lds();
+
+        // Compute the inverse norms after all K-tiles have been processed
+        ComputeDataType inv_norms[kMTile];
+        for(index_t local_m = 0; local_m < kMTile; ++local_m)
+        {
+            const index_t global_m = batch_start + local_m;
+            if(global_m < batch)
+            {
+                // norm = sqrt(sum_sq) / sqrt(nC) == sqrt(sum_sq / nC), i.e. the RMS norm
+                const ComputeDataType norm = ck_tile::sqrt(sum_squares_shared[local_m]) /
+                                             ck_tile::sqrt(static_cast<ComputeDataType>(nC));
+                inv_norms[local_m] = 1.0f / norm;
+            }
+            else
+            {
+                inv_norms[local_m] = 1.0f;
+            }
+        }
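+
+        // Output layout along global_n (three sections; output_dim = 2n + n^2):
+        //   [0, n)        pre-mapping:  y = alpha_pre  * inv_norm * act(v) + bias
+        //   [n, 2n)       post-mapping: y = alpha_post * inv_norm * 2 * act(v) + bias
+        //   [2n, 2n+n^2)  residual mix: y = alpha_res  * inv_norm * v + bias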
+        // Apply normalization and activation, then write the output
+        for(index_t i = tid; i < kMTile * kNTile; i += get_block_size())
+        {
+            const index_t local_m = i / kNTile;
+            const index_t local_n = i % kNTile;
+
+            const index_t global_m = batch_start + local_m;
+            const index_t global_n = out_start + local_n;
+
+            if(global_m >= batch || global_n >= output_dim)
+                continue;
+
+            const ComputeDataType inv_norm = inv_norms[local_m];
+            ComputeDataType value          = result_shared[i];
+
+            if(global_n < n)
+            {
+                ComputeDataType activated_value;
+                Activation{}(activated_value, value);
+                value = alpha_pre * inv_norm * activated_value + bias;
+            }
+            else if(global_n < 2 * n)
+            {
+                ComputeDataType activated_value;
+                Activation{}(activated_value, value);
+                value = alpha_post * inv_norm * 2.0f * activated_value + bias;
+            }
+            else
+            {
+                value = alpha_res * inv_norm * value + bias;
+            }
+
+            // Write to global memory
+            p_output[global_m * output_dim + global_n] = type_convert<YDataType>(value);
+        }
+    }
+};
+
+} // namespace ck_tile
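+
+// Usage sketch (a minimal outline mirroring example/ck_tile/42_mhc/mhc_v4_bf16_benchmark.cpp;
+// the Problem and activation types are choices of that example, not fixed by this kernel):
+//   using Kernel = ck_tile::MHCKernelV4<Problem, ck_tile::element_wise::Sigmoid>;
+//   auto grids = Kernel::GetGridSize(B, output_dim);
+//   // launch grids.at(number<0>{}) * grids.at(number<1>{}) blocks of
+//   // Kernel::BlockSize() threads, passing Kernel::GetSmemSize() bytes of LDS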