Add V4: remove gemm pipeline, combine gemm/normalization

Damien Lejeune
2026-02-10 10:39:49 +00:00
parent 6c45f722e7
commit 7c728adb57
5 changed files with 482 additions and 6 deletions

CMakeLists.txt

@@ -15,3 +15,6 @@ add_executable(${TARGET_NAME} mhc_v3_two_block_test.cpp)
set(TARGET_NAME example_mhc_v3_bf16_benchmark)
add_executable(${TARGET_NAME} mhc_v3_bf16_benchmark.cpp)
set(TARGET_NAME example_mhc_v4_bf16_benchmark)
add_executable(${TARGET_NAME} mhc_v4_bf16_benchmark.cpp)

mhc_v4_bf16_benchmark.cpp (new file)

@@ -0,0 +1,216 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <vector>
#include <cmath>
#include <tuple>
#include <iostream>
#include <cstring>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/mhc.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/reference/reference_mhc.hpp"
#include "ck_tile/host/check_err.hpp"
// Parse command-line arguments for MHC benchmark
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("B", "1024", "Batch size")
.insert("n", "4", "Expansion factor (number of streams)")
.insert("C", "4096", "Channels per stream")
.insert("v", "1", "CPU validation (0=no, 1=yes)")
.insert("warmup", "5", "Number of warmup iterations")
.insert("repeat", "20", "Number of benchmark iterations")
.insert("r", "2.0", "Norm scaling factor")
.insert("alpha_pre", "1.5", "Alpha for pre-activation")
.insert("alpha_post", "2.5", "Alpha for post-activation")
.insert("alpha_res", "3.5", "Alpha for residual")
.insert("bias", "1.5", "Bias value");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
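// Example invocation (hypothetical, assuming ck_tile::ArgParser's usual
// -key=value syntax):
//   ./example_mhc_v4_bf16_benchmark -B=2048 -C=4096 -n=4 -v=1 -repeat=50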
template <typename XDataType,
typename PhiDataType,
typename YDataType,
typename ComputeDataType,
typename ActivationFunc = ck_tile::element_wise::Sigmoid>
bool run_mhc_benchmark(const ck_tile::ArgParser& arg_parser)
{
const int B = arg_parser.get_int("B");
const int n = arg_parser.get_int("n");
const int C = arg_parser.get_int("C");
const int nC = n * C;
const int output_dim = 2 * n + n * n;
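// Output layout per row, matching the kernel's three post-processing sections:
//   [0, n)         -> alpha_pre path  (activated)
//   [n, 2n)        -> alpha_post path (2x activated)
//   [2n, 2n + n^2) -> alpha_res path  (linear)
// e.g. with the default n = 4: output_dim = 2*4 + 4*4 = 24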
const int do_validation = arg_parser.get_int("v");
const int warmup = arg_parser.get_int("warmup");
const int repeat = arg_parser.get_int("repeat");
const float r = arg_parser.get_float("r");
const float alpha_pre = arg_parser.get_float("alpha_pre");
const float alpha_post = arg_parser.get_float("alpha_post");
const float alpha_res = arg_parser.get_float("alpha_res");
const float bias = arg_parser.get_float("bias");
std::cout << "\n========================================" << std::endl;
std::cout << "MHC Kernel V4 Benchmark (BF16)" << std::endl;
std::cout << "========================================" << std::endl;
std::cout << "Configuration:" << std::endl;
std::cout << " Batch size (B): " << B << std::endl;
std::cout << " Expansion factor (n): " << n << std::endl;
std::cout << " Channels per stream (C): " << C << std::endl;
std::cout << " Input dimension (nC): " << nC << std::endl;
std::cout << " Output dimension (2n+n^2): " << output_dim << std::endl;
std::cout << " Data types: X=" << typeid(XDataType).name()
<< ", Phi=" << typeid(PhiDataType).name() << ", Y=" << typeid(YDataType).name()
<< ", Compute=" << typeid(ComputeDataType).name() << std::endl;
std::cout << " Warmup iterations: " << warmup << std::endl;
std::cout << " Benchmark iterations: " << repeat << std::endl;
std::cout << "========================================" << std::endl;
// Allocate host tensors
ck_tile::HostTensor<XDataType> h_x({B, nC});
ck_tile::HostTensor<PhiDataType> h_phi({nC, output_dim});
ck_tile::HostTensor<YDataType> h_output({B, output_dim});
// Initialize with random data
ck_tile::FillUniformDistribution<XDataType>{-1.0f, 1.0f}(h_x);
ck_tile::FillUniformDistribution<PhiDataType>{-0.5f, 0.5f}(h_phi);
h_output.SetZero();
// Allocate device memory
ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes());
// Copy data to device
d_x_mem.ToDevice(h_x.data());
d_phi_mem.ToDevice(h_phi.data());
d_output_mem.ToDevice(h_output.data());
// Define block shape - 64 threads (1 warp) to match BlockGemmShape configuration
// This matches a 16x16 block tile with 1 warp (1x1 warp layout)
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 64>,
ck_tile::sequence<1, 64>,
ck_tile::sequence<1, 1>>;
using Problem = ck_tile::MHCProblem<XDataType, ComputeDataType, YDataType, BlockShape>;
// V4 kernel - optimized with single-pass data loading
using KernelV4 = ck_tile::MHCKernelV4<Problem, ck_tile::MHCDefaultPolicy, ActivationFunc>;
const ck_tile::index_t kBlockSize = KernelV4::BlockSize();
// 2D grid: (batch / kMTile) × (output_dim / kNTile)
auto grid_size = KernelV4::GetGridSize(B, output_dim);
const ck_tile::index_t kGridSize =
grid_size.at(ck_tile::number<0>{}) * grid_size.at(ck_tile::number<1>{});
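// Worked example, assuming the 16x16 block tile noted above (kMTile = kNTile = 16):
// B = 1024, output_dim = 24 -> grid = ceil(1024/16) * ceil(24/16) = 64 * 2 = 128 blocks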
std::cout << "\nKernel Configuration:" << std::endl;
std::cout << " Grid: " << grid_size.at(ck_tile::number<0>{}) << " × "
<< grid_size.at(ck_tile::number<1>{}) << " = " << kGridSize << " blocks" << std::endl;
std::cout << " Block size: " << kBlockSize << " threads" << std::endl;
std::cout << " Shared memory: " << KernelV4::GetSmemSize() << " bytes" << std::endl;
constexpr ck_tile::index_t kBlockPerCu = 1;
// Launch kernel with timing
float ave_time = ck_tile::launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockPerCu>(KernelV4{},
kGridSize,
kBlockSize,
KernelV4::GetSmemSize(),
static_cast<XDataType*>(d_x_mem.GetDeviceBuffer()),
static_cast<PhiDataType*>(d_phi_mem.GetDeviceBuffer()),
static_cast<YDataType*>(d_output_mem.GetDeviceBuffer()),
B,
nC,
output_dim,
n,
r,
alpha_pre,
alpha_post,
alpha_res,
bias));
// Calculate performance metrics
std::size_t num_bytes = sizeof(XDataType) * B * nC + // Input x
sizeof(PhiDataType) * nC * output_dim + // Weights phi
sizeof(YDataType) * B * output_dim; // Output
float gb_per_sec = num_bytes / 1.E6 / ave_time;
// Calculate FLOPs: 2*nC multiply-add FLOPs per output element for the GEMM
// (the normalization/activation post-processing is negligible by comparison)
std::size_t num_flops = static_cast<std::size_t>(B) * output_dim * (2 * nC);
float tflops = num_flops / 1.E9 / ave_time;
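// Sanity check with the defaults (B = 1024, nC = 4 * 4096 = 16384, output_dim = 24):
//   num_bytes ~= 2*1024*16384 + 2*16384*24 + 4*1024*24 ~= 34.4 MB
//   num_flops  = 1024 * 24 * 2 * 16384 ~= 0.805 GFLOP
// so an average time of 1 ms would report ~34 GB/s and ~0.8 TFLOPS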
std::cout << "\n========================================" << std::endl;
std::cout << "Performance Results:" << std::endl;
std::cout << " Average time: " << ave_time << " ms" << std::endl;
std::cout << " Bandwidth: " << gb_per_sec << " GB/s" << std::endl;
std::cout << " Throughput: " << tflops << " TFLOPS" << std::endl;
std::cout << "========================================" << std::endl;
bool pass = true;
if(do_validation)
{
std::cout << "\nRunning validation..." << std::endl;
d_output_mem.FromDevice(h_output.data());
// Compute reference
ck_tile::HostTensor<YDataType> h_output_ref({B, output_dim});
h_output_ref.SetZero();
ck_tile::reference_mhc<XDataType, PhiDataType, YDataType, ComputeDataType, ActivationFunc>(
h_x,
h_phi,
h_output_ref,
n,
C,
r,
alpha_pre,
alpha_post,
alpha_res,
bias,
ActivationFunc{});
// Validate with appropriate tolerance for bf16
float rtol = std::is_same_v<XDataType, ck_tile::bf16_t> ? 1e-2f : 1e-3f;
float atol = std::is_same_v<XDataType, ck_tile::bf16_t> ? 1e-2f : 1e-3f;
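// (bf16 stores an 8-bit significand, i.e. ~2-3 significant decimal digits,
// which is why the bf16 path above relaxes both tolerances to 1e-2)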
pass = ck_tile::check_err(
h_output, h_output_ref, "Error: MHC V4 kernel output mismatch!", rtol, atol);
std::cout << "Validation: " << (pass ? "PASS" : "FAIL") << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
std::cout << "Failed to parse arguments!" << std::endl;
return -1;
}
// Run with BF16 inputs, float output and compute
bool pass = run_mhc_benchmark<ck_tile::bf16_t, // XDataType
ck_tile::bf16_t, // PhiDataType
float, // YDataType
float, // ComputeDataType
ck_tile::element_wise::Sigmoid>(arg_parser);
return pass ? 0 : -2;
}

ck_tile/ops/mhc.hpp

@@ -6,6 +6,7 @@
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile.hpp"
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v2.hpp"
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp"
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"

ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp

@@ -181,11 +181,11 @@ struct MHCKernelV3
auto phi_identity_func = [](auto& e, const PhiDataType& phi_val) { e = phi_val; };
auto result_tile = gemm_pipeline(make_tuple(x_dram_window),
- phi_identity_func,
- make_tuple(phi_dram_window),
- phi_identity_func,
- num_k_loops,
- smem);
+ phi_identity_func,
+ make_tuple(phi_dram_window),
+ phi_identity_func,
+ num_k_loops,
+ smem);
// Apply normalization and activation in post-processing
// Now we divide by norm AFTER the GEMM, which means:
@@ -222,7 +222,8 @@ struct MHCKernelV3
{
ComputeDataType activated_value;
Activation{}(activated_value, value);
- result_tile(i_j_idx) = alpha_post * inv_norm * 2.0f * activated_value + bias;
+ result_tile(i_j_idx) =
+ alpha_post * inv_norm * 2.0f * activated_value + bias;
}
else
{

ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp (new file)

@@ -0,0 +1,255 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
// Manifold Constrained Hyper Connection Kernel V4:
// =====================================================================
// Optimizations implemented:
// - Remove the GEMM pipeline to avoid redundant global memory reads
// - Manage LDS tiles manually and run an explicit in-kernel GEMM loop (like v2)
// - Compute the normalization incrementally during the GEMM loop
// - Single pass through the input data: load once, compute norm and GEMM together
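//
// Per input row x (length nC), with v = x . phi and rms = sqrt(sum(x^2) / nC),
// the post-processing below computes (a sketch reconstructed from the code):
//   y[j] = alpha_pre  * act(v[j]) / rms + bias       for j in [0, n)
//   y[j] = alpha_post * 2 * act(v[j]) / rms + bias   for j in [n, 2n)
//   y[j] = alpha_res  * v[j] / rms + bias            for j in [2n, 2n + n^2)
// where act is the Activation functor (Sigmoid by default). Note that act is
// applied to the raw GEMM value and the 1/rms scale afterwards, matching V3's
// "divide by norm AFTER the GEMM" scheme.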
namespace ck_tile {
template <typename Problem_,
typename Policy_ = MHCDefaultPolicy,
typename Activation_ = element_wise::Sigmoid>
struct MHCKernelV4
{
using Activation = ck_tile::remove_cvref_t<Activation_>;
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using PhiDataType = ck_tile::remove_cvref_t<typename Problem::PhiDataType>;
// Automatically derive tile sizes from BlockGemmShape (single source of truth!)
static constexpr index_t kMTile = Problem::BlockGemmShape::kM; // Batch tile
static constexpr index_t kNTile = Problem::BlockGemmShape::kN; // Output tile
static constexpr index_t kKTile = Problem::BlockGemmShape::kK; // K tile for C dimension
static constexpr index_t kBlockSize = Problem::kBlockSize;
CK_TILE_HOST static constexpr auto BlockSize() { return kBlockSize; }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
// LDS for the GEMM tile buffers: A[kMTile, kKTile] + B[kKTile, kNTile].
// The norm and result accumulators are separate static __shared__ arrays
// declared in operator() and are not counted here.
constexpr index_t a_lds_size = kMTile * kKTile * sizeof(XDataType);
constexpr index_t b_lds_size = kKTile * kNTile * sizeof(PhiDataType);
return a_lds_size + b_lds_size;
}
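// e.g. with an assumed 16x16x16 tile and bf16 data:
//   a_lds_size = 16*16*2 B = 512 B, b_lds_size = 16*16*2 B = 512 B -> 1024 B total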
// Grid configuration: 2D grid over (batch, output_dim)
CK_TILE_HOST static constexpr auto GetGridSize(index_t batch, index_t output_dim)
{
const index_t grid_m = (batch + kMTile - 1) / kMTile;
const index_t grid_n = (output_dim + kNTile - 1) / kNTile;
return make_tuple(grid_m, grid_n);
}
CK_TILE_DEVICE void operator()(const XDataType* p_x,
const PhiDataType* p_phi,
YDataType* p_output,
index_t batch,
index_t nC,
index_t output_dim,
index_t n,
[[maybe_unused]] float r = 1.0f,
float alpha_pre = 1.0f,
float alpha_post = 1.0f,
float alpha_res = 1.0f,
float bias = 0.0f) const
{
// 2D block indexing
const index_t grid_n_size = (output_dim + kNTile - 1) / kNTile;
const index_t block_id = get_block_id();
const index_t block_m = block_id / grid_n_size;
const index_t block_n = block_id % grid_n_size;
const index_t batch_start = block_m * kMTile;
const index_t out_start = block_n * kNTile;
if(batch_start >= batch || out_start >= output_dim)
return;
const index_t tid = get_thread_id();
// Statically allocate LDS for the A and B tiles; GetSmemSize() covers only
// these two buffers (the norm and result accumulators are declared below).
// alignas keeps the reinterpret_casts below safely aligned.
__shared__ alignas(16) char smem_ptr[GetSmemSize()];
XDataType* x_lds = reinterpret_cast<XDataType*>(smem_ptr);
PhiDataType* phi_lds =
reinterpret_cast<PhiDataType*>(smem_ptr + kMTile * kKTile * sizeof(XDataType));
// Shared memory for norm accumulation (one per batch element in tile)
__shared__ ComputeDataType sum_squares_shared[kMTile];
// Shared memory for GEMM result accumulation
__shared__ ComputeDataType result_shared[kMTile * kNTile];
// Initialize shared norm accumulators and result
for(index_t i = tid; i < kMTile; i += get_block_size())
{
sum_squares_shared[i] = 0.0f;
}
for(index_t i = tid; i < kMTile * kNTile; i += get_block_size())
{
result_shared[i] = 0.0f;
}
block_sync_lds();
// Number of K-tile iterations
const index_t num_k_tiles = (nC + kKTile - 1) / kKTile;
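// e.g. the default nC = 16384 with an assumed kKTile of 16 gives 1024 iterations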
// Main loop: load tiles, compute norms incrementally, and accumulate GEMM
for(index_t k_tile_idx = 0; k_tile_idx < num_k_tiles; ++k_tile_idx)
{
const index_t k_start = k_tile_idx * kKTile;
const index_t k_end = min(k_start + kKTile, nC);
const index_t k_len = k_end - k_start;
// Load X tile from global to LDS and accumulate norm
for(index_t i = tid; i < kMTile * kKTile; i += get_block_size())
{
const index_t local_m = i / kKTile;
const index_t local_k = i % kKTile;
const index_t global_m = batch_start + local_m;
const index_t global_k = k_start + local_k;
XDataType x_val = 0;
if(global_m < batch && local_k < k_len)
{
x_val = p_x[global_m * nC + global_k];
// Accumulate norm for this batch element using atomics
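// (all threads covering one row of the tile contend on the same shared
// accumulator, so these adds serialize; acceptable for a narrow K tile)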
ComputeDataType x_compute = type_convert<ComputeDataType>(x_val);
ComputeDataType sq = x_compute * x_compute;
atomicAdd(&sum_squares_shared[local_m], sq);
}
x_lds[i] = x_val;
}
// Load Phi tile from global to LDS in column-major format (K x N)
// phi is stored in global memory as [nC, output_dim] row-major
// We need to transpose it to [K, N] column-major for BlockGemm
for(index_t i = tid; i < kKTile * kNTile; i += get_block_size())
{
const index_t local_k = i / kNTile;
const index_t local_n = i % kNTile;
const index_t global_k = k_start + local_k;
const index_t global_n = out_start + local_n;
PhiDataType phi_val = 0;
if(local_k < k_len && global_n < output_dim)
{
phi_val = p_phi[global_k * output_dim + global_n];
}
// Store in column-major: phi_lds[n * kKTile + k]
phi_lds[local_n * kKTile + local_k] = phi_val;
}
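// Storing phi column-major lets the GEMM loop below read both
// x_lds[m*kKTile + k] and phi_lds[n*kKTile + k] with unit stride in k.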
block_sync_lds();
// Perform manual GEMM: result_acc += x_lds * phi_lds^T
// Distribute work: each thread computes a subset of output elements
// With 64 threads and 16x16 output, each thread handles 4 elements
const index_t total_elements = kMTile * kNTile;
const index_t elements_per_thread =
(total_elements + get_block_size() - 1) / get_block_size();
for(index_t elem_idx = 0; elem_idx < elements_per_thread; ++elem_idx)
{
const index_t global_elem = tid * elements_per_thread + elem_idx;
if(global_elem < total_elements)
{
const index_t m_idx = global_elem / kNTile;
const index_t n_idx = global_elem % kNTile;
ComputeDataType acc = 0.0f;
for(index_t k_idx = 0; k_idx < kKTile; ++k_idx)
{
ComputeDataType x_val =
type_convert<ComputeDataType>(x_lds[m_idx * kKTile + k_idx]);
ComputeDataType phi_val =
type_convert<ComputeDataType>(phi_lds[n_idx * kKTile + k_idx]);
acc += x_val * phi_val;
}
// Accumulate to shared memory using atomics
atomicAdd(&result_shared[m_idx * kNTile + n_idx], acc);
}
}
block_sync_lds();
}
// Ensure all norm accumulations are complete
block_sync_lds();
// Compute inverse norms after all K-tiles processed
ComputeDataType inv_norms[kMTile];
for(index_t local_m = 0; local_m < kMTile; ++local_m)
{
const index_t global_m = batch_start + local_m;
if(global_m < batch)
{
const ComputeDataType norm = ck_tile::sqrt(sum_squares_shared[local_m]) /
ck_tile::sqrt(static_cast<ComputeDataType>(nC));
inv_norms[local_m] = 1.0f / norm;
}
else
{
inv_norms[local_m] = 1.0f;
}
}
// Apply normalization, activation, and write output
for(index_t i = tid; i < kMTile * kNTile; i += get_block_size())
{
const index_t local_m = i / kNTile;
const index_t local_n = i % kNTile;
const index_t global_m = batch_start + local_m;
const index_t global_n = out_start + local_n;
if(global_m >= batch || global_n >= output_dim)
continue;
const ComputeDataType inv_norm = inv_norms[local_m];
ComputeDataType value = result_shared[i];
// Apply normalization and activation based on output section
if(global_n < n)
{
ComputeDataType activated_value;
Activation{}(activated_value, value);
value = alpha_pre * inv_norm * activated_value + bias;
}
else if(global_n < 2 * n)
{
ComputeDataType activated_value;
Activation{}(activated_value, value);
value = alpha_post * inv_norm * 2.0f * activated_value + bias;
}
else
{
value = alpha_res * inv_norm * value + bias;
}
// Write to global memory
p_output[global_m * output_dim + global_n] = type_convert<YDataType>(value);
}
}
};
} // namespace ck_tile