diff --git a/example/ck_tile/42_mhc/CMakeLists.txt b/example/ck_tile/42_mhc/CMakeLists.txt index 4c32b03206..99962ed953 100644 --- a/example/ck_tile/42_mhc/CMakeLists.txt +++ b/example/ck_tile/42_mhc/CMakeLists.txt @@ -13,4 +13,5 @@ add_executable(${TARGET_NAME} mhc_v3_single_block_test.cpp) set(TARGET_NAME example_mhc_v3_two_block_test) add_executable(${TARGET_NAME} mhc_v3_two_block_test.cpp) - +set(TARGET_NAME example_mhc_v3_bf16_benchmark) +add_executable(${TARGET_NAME} mhc_v3_bf16_benchmark.cpp) diff --git a/example/ck_tile/42_mhc/mhc_v3.cpp b/example/ck_tile/42_mhc/mhc_v3.cpp index b5117c8d8e..dd9a33309a 100644 --- a/example/ck_tile/42_mhc/mhc_v3.cpp +++ b/example/ck_tile/42_mhc/mhc_v3.cpp @@ -48,20 +48,15 @@ int main() d_phi_mem.ToDevice(h_phi.data()); d_output_mem.ToDevice(h_output.data()); - // Define block shape - must match BlockGemmShape thread count (2 warps × 64 = 128 threads) + // Define block shape - 128 threads (2 warps) to match BlockGemmShape configuration using BlockShape = ck_tile::Generic2dBlockShape, ck_tile::sequence<1, 128>, ck_tile::sequence<1, 1>>; using Problem = ck_tile::MHCProblem; - // V3 kernel with 2D tiling - constexpr ck_tile::index_t kMTile = 64; // Batch tile - constexpr ck_tile::index_t kNTile = 32; // Output tile (exactly covers 24 outputs for n=4) - constexpr ck_tile::index_t kKTile = 8; // K tile for C dimension (must match BlockGemmShape::kK) - - using KernelV3 = ck_tile:: - MHCKernelV3; + // V3 kernel - tile sizes automatically derived from Problem::BlockGemmShape + using KernelV3 = ck_tile::MHCKernelV3; const ck_tile::index_t kBlockSize = KernelV3::BlockSize(); diff --git a/example/ck_tile/42_mhc/mhc_v3_bf16_benchmark.cpp b/example/ck_tile/42_mhc/mhc_v3_bf16_benchmark.cpp new file mode 100644 index 0000000000..15be81d9e7 --- /dev/null +++ b/example/ck_tile/42_mhc/mhc_v3_bf16_benchmark.cpp @@ -0,0 +1,215 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/mhc.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/reference/reference_mhc.hpp" +#include "ck_tile/host/check_err.hpp" + +// Parse command-line arguments for MHC benchmark +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("B", "1024", "Batch size") + .insert("n", "4", "Expansion factor (number of streams)") + .insert("C", "4096", "Channels per stream") + .insert("v", "1", "CPU validation (0=no, 1=yes)") + .insert("warmup", "5", "Number of warmup iterations") + .insert("repeat", "20", "Number of benchmark iterations") + .insert("r", "2.0", "Norm scaling factor") + .insert("alpha_pre", "1.5", "Alpha for pre-activation") + .insert("alpha_post", "2.5", "Alpha for post-activation") + .insert("alpha_res", "3.5", "Alpha for residual") + .insert("bias", "1.5", "Bias value"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run_mhc_benchmark(const ck_tile::ArgParser& arg_parser) +{ + const int B = arg_parser.get_int("B"); + const int n = arg_parser.get_int("n"); + const int C = arg_parser.get_int("C"); + + const int nC = n * C; + const int output_dim = 2 * n + n * n; + + const int do_validation = arg_parser.get_int("v"); + const int warmup = arg_parser.get_int("warmup"); + const int repeat = arg_parser.get_int("repeat"); + + const float r = arg_parser.get_float("r"); + const float alpha_pre = arg_parser.get_float("alpha_pre"); + const float alpha_post = arg_parser.get_float("alpha_post"); + const float alpha_res = arg_parser.get_float("alpha_res"); + const float bias = arg_parser.get_float("bias"); + + std::cout << "\n========================================" << std::endl; + std::cout << "MHC Kernel V3 Benchmark (BF16)" << std::endl; + std::cout << 
"========================================" << std::endl; + std::cout << "Configuration:" << std::endl; + std::cout << " Batch size (B): " << B << std::endl; + std::cout << " Expansion factor (n): " << n << std::endl; + std::cout << " Channels per stream (C): " << C << std::endl; + std::cout << " Input dimension (nC): " << nC << std::endl; + std::cout << " Output dimension (2n+n^2): " << output_dim << std::endl; + std::cout << " Data types: X=" << typeid(XDataType).name() + << ", Phi=" << typeid(PhiDataType).name() << ", Y=" << typeid(YDataType).name() + << ", Compute=" << typeid(ComputeDataType).name() << std::endl; + std::cout << " Warmup iterations: " << warmup << std::endl; + std::cout << " Benchmark iterations: " << repeat << std::endl; + std::cout << "========================================" << std::endl; + + // Allocate host tensors + ck_tile::HostTensor h_x({B, nC}); + ck_tile::HostTensor h_phi({nC, output_dim}); + ck_tile::HostTensor h_output({B, output_dim}); + + // Initialize with random data + ck_tile::FillUniformDistribution{-1.0f, 1.0f}(h_x); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(h_phi); + h_output.SetZero(); + + // Allocate device memory + ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes()); + + // Copy data to device + d_x_mem.ToDevice(h_x.data()); + d_phi_mem.ToDevice(h_phi.data()); + d_output_mem.ToDevice(h_output.data()); + + // Define block shape - 128 threads (2 warps) to match BlockGemmShape configuration + using BlockShape = ck_tile::Generic2dBlockShape, + ck_tile::sequence<1, 128>, + ck_tile::sequence<1, 1>>; + + using Problem = ck_tile::MHCProblem; + + // V3 kernel - tile sizes automatically derived from Problem::BlockGemmShape + using KernelV3 = ck_tile::MHCKernelV3; + + const ck_tile::index_t kBlockSize = KernelV3::BlockSize(); + + // 2D grid: (batch / kMTile) × 
(output_dim / kNTile) + auto grid_size = KernelV3::GetGridSize(B, output_dim); + const ck_tile::index_t kGridSize = + grid_size.at(ck_tile::number<0>{}) * grid_size.at(ck_tile::number<1>{}); + + std::cout << "\nKernel Configuration:" << std::endl; + std::cout << " Grid: " << grid_size.at(ck_tile::number<0>{}) << " × " + << grid_size.at(ck_tile::number<1>{}) << " = " << kGridSize << " blocks" << std::endl; + std::cout << " Block size: " << kBlockSize << " threads" << std::endl; + std::cout << " Shared memory: " << KernelV3::GetSmemSize() << " bytes" << std::endl; + + constexpr ck_tile::index_t kBlockPerCu = 1; + + // Launch kernel with timing + float ave_time = ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel(KernelV3{}, + kGridSize, + kBlockSize, + KernelV3::GetSmemSize(), + static_cast(d_x_mem.GetDeviceBuffer()), + static_cast(d_phi_mem.GetDeviceBuffer()), + static_cast(d_output_mem.GetDeviceBuffer()), + B, + nC, + output_dim, + n, + r, + alpha_pre, + alpha_post, + alpha_res, + bias)); + + // Calculate performance metrics + std::size_t num_bytes = sizeof(XDataType) * B * nC + // Input x + sizeof(PhiDataType) * nC * output_dim + // Weights phi + sizeof(YDataType) * B * output_dim; // Output + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + // Estimate FLOPs as B * output_dim * 2 * nC, i.e. one multiply and one add per GEMM MAC; the additional non-GEMM ops are not counted separately + std::size_t num_flops = static_cast(B) * output_dim * (2 * nC); + float tflops = num_flops / 1.E9 / ave_time; + + std::cout << "\n========================================" << std::endl; + std::cout << "Performance Results:" << std::endl; + std::cout << " Average time: " << ave_time << " ms" << std::endl; + std::cout << " Bandwidth: " << gb_per_sec << " GB/s" << std::endl; + std::cout << " Throughput: " << tflops << " TFLOPS" << std::endl; + std::cout << "========================================" << std::endl; + + bool pass = true; + + if(do_validation) + { + std::cout << "\nRunning 
validation..." << std::endl; + + d_output_mem.FromDevice(h_output.data()); + + // Compute reference + ck_tile::HostTensor h_output_ref({B, output_dim}); + h_output_ref.SetZero(); + + ck_tile::reference_mhc( + h_x, + h_phi, + h_output_ref, + n, + C, + r, + alpha_pre, + alpha_post, + alpha_res, + bias, + ActivationFunc{}); + + // Validate with appropriate tolerance for bf16 + float rtol = std::is_same_v ? 1e-2f : 1e-3f; + float atol = std::is_same_v ? 1e-2f : 1e-3f; + + pass = ck_tile::check_err( + h_output, h_output_ref, "Error: MHC V3 kernel output mismatch!", rtol, atol); + + std::cout << "Validation: " << (pass ? "PASS" : "FAIL") << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + std::cout << "Failed to parse arguments!" << std::endl; + return -1; + } + + // Run with BF16 inputs, float output and compute + bool pass = run_mhc_benchmark(arg_parser); + + return pass ? 0 : -2; +} diff --git a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp index 5a1156ac3d..a4513057d3 100644 --- a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp +++ b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp @@ -22,9 +22,6 @@ namespace ck_tile { template struct MHCKernelV3 { @@ -37,9 +34,10 @@ struct MHCKernelV3 using YDataType = ck_tile::remove_cvref_t; using PhiDataType = ck_tile::remove_cvref_t; - static constexpr index_t kMTile = kMTile_; // Batch tile - static constexpr index_t kNTile = kNTile_; // Output tile - static constexpr index_t kKTile = kKTile_; // K tile for C dimension + // Automatically derive tile sizes from BlockGemmShape (single source of truth!) 
+ static constexpr index_t kMTile = Problem::BlockGemmShape::kM; // Batch tile + static constexpr index_t kNTile = Problem::BlockGemmShape::kN; // Output tile + static constexpr index_t kKTile = Problem::BlockGemmShape::kK; // K tile for C dimension static constexpr index_t kBlockSize = Problem::kBlockSize; diff --git a/include/ck_tile/ops/mhc/pipeline/mhc_problem.hpp b/include/ck_tile/ops/mhc/pipeline/mhc_problem.hpp index 0257acdb69..6096605357 100644 --- a/include/ck_tile/ops/mhc/pipeline/mhc_problem.hpp +++ b/include/ck_tile/ops/mhc/pipeline/mhc_problem.hpp @@ -26,12 +26,12 @@ struct MHCProblem using CDataType = ComputeDataType; // Output/accumulator matrix C // BlockGemmShape with kM, kN, kK members for BlockGemm - // Use supported warp gemm configuration for float32: 32x32x8 - // We'll use 2 warps in M and 1 warp in N to get 64x32 block + // Using 32x32x8 warp tiles (supported by MFMA) with 2x1 warp layout for 64x32 block + // (2 warps x 32 rows in M, 1 warp x 32 cols in N), using only supported warp sizes. NOTE(review): the earlier wording claimed better parallelism than 64x32, i.e. than this same tile; confirm the intended baseline using BlockGemmShape = - TileGemmShape, // BlockTile (M, N, K) + TileGemmShape, // BlockTile (M, N, K) - keep original for now sequence<2, 1, 1>, // BlockWarps (2 warps in M, 1 in N, 1 in K) - sequence<32, 32, 8>>; // WarpTile (matches available float32 MFMA) + sequence<32, 32, 8>>; // WarpTile (32x32x8 is supported by MFMA) // Layout types for BlockGemm using ALayout = ck_tile::tensor_layout::gemm::RowMajor; // x is row-major [B, nC] diff --git a/test/ck_tile/mhc/test_mhc_impl.hpp b/test/ck_tile/mhc/test_mhc_impl.hpp index 0380cc8679..ef43bdfe74 100644 --- a/test/ck_tile/mhc/test_mhc_impl.hpp +++ b/test/ck_tile/mhc/test_mhc_impl.hpp @@ -677,21 +677,15 @@ class TestCkTileMHC : public ::testing::Test d_phi_mem.ToDevice(h_phi.data()); d_output_mem.ToDevice(h_output.data()); - // Define block shape - must match BlockGemmShape thread count (2 warps × 64 = 128 threads) + // Define block shape - 128 threads (2 warps) to match BlockGemmShape configuration using 
BlockShape = ck_tile::Generic2dBlockShape, ck_tile::sequence<1, 128>, ck_tile::sequence<1, 1>>; using Problem = ck_tile::MHCProblem; - // V3 kernel with 2D tiling - constexpr ck_tile::index_t kMTile = 64; // Batch tile - constexpr ck_tile::index_t kNTile = 32; // Output tile (exactly covers 24 outputs for n=4) - constexpr ck_tile::index_t kKTile = - 8; // K tile for C dimension (must match BlockGemmShape::kK) - - using KernelV3 = ck_tile:: - MHCKernelV3; + // V3 kernel - tile sizes automatically derived from Problem::BlockGemmShape + using KernelV3 = ck_tile::MHCKernelV3; const ck_tile::index_t kBlockSize = KernelV3::BlockSize();