WIP: MHC v3

This commit is contained in:
Damien Lejeune
2026-02-05 13:04:18 +00:00
parent 6ea40157f1
commit 43a5678fdf
13 changed files with 957 additions and 41 deletions

View File

@@ -0,0 +1,16 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Build one example executable per MHC test source (example_<name> <- <name>.cpp).
foreach(_mhc_example mhc mhc_v3 mhc_v3_single_block_test mhc_v3_two_block_test)
    set(TARGET_NAME example_${_mhc_example})
    add_executable(${TARGET_NAME} ${_mhc_example}.cpp)
endforeach()

View File

@@ -0,0 +1,134 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <tuple>
#include <vector>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/mhc.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/reference/reference_mhc.hpp"
#include "ck_tile/host/check_err.hpp"
// Functional test for the MHC "expansion parallel" tiled kernel: runs the
// device kernel and the host reference with the same activation function on
// random data and compares the results element-wise.
int main()
{
const int B = 1024; // Batch size
const int n = 4; // Expansion rate (aka streams)
const int C = 4096; // Output layer dim
const int nC = n * C; // Total input dimension
const int output_dim = 2 * n + n * n; // 2n + n^2
using ActivationFunc = ck_tile::element_wise::Sigmoid; // Example activation function
std::cout << "\n--- Testing batch size B=" << B << " (n=" << n << ", C=" << C
<< ") with activation: " << ActivationFunc::name << " ---" << std::endl;
// Allocate host tensors
ck_tile::HostTensor<float> h_x({B, nC});
ck_tile::HostTensor<float> h_phi({nC, output_dim});
ck_tile::HostTensor<float> h_output({B, output_dim});
// Initialize with random data
ck_tile::FillUniformDistribution<float>{-1.0f, 1.0f}(h_x);
ck_tile::FillUniformDistribution<float>{-0.5f, 0.5f}(h_phi);
h_output.SetZero();
// Allocate device memory
ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes());
// Copy data to device
d_x_mem.ToDevice(h_x.data());
d_phi_mem.ToDevice(h_phi.data());
d_output_mem.ToDevice(h_output.data());
// Define block shape
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 256>,
ck_tile::sequence<1, 256>,
ck_tile::sequence<1, 1>>;
using Problem = ck_tile::MHCProblem<float, float, float, BlockShape>;
// Use template parameters for B, n, C, and Activation (compile-time)
using KernelExpansionParallel =
ck_tile::ManifoldConstrainedHyperConnectionTiled<Problem,
ck_tile::MHCDefaultPolicy,
B,
n,
C,
256,
ActivationFunc>;
const ck_tile::index_t kBlockSize = KernelExpansionParallel::BlockSize();
// One block per 16 outputs (ceil division). NOTE(review): 16 is a magic
// number here — presumably the kernel's per-block output tile; confirm it
// matches the kernel's internal tiling.
const ck_tile::index_t kGridSize = (output_dim + 15) / 16;
constexpr ck_tile::index_t kBlockPerCu = 1;
const float r = 2.0f, alpha_pre = 1.5f, alpha_post = 2.5f, alpha_res = 3.5f, bias = 1.5f;
// Launch kernel
ck_tile::launch_kernel(
ck_tile::stream_config{nullptr, false, 0},
ck_tile::make_kernel<kBlockPerCu>(KernelExpansionParallel{},
kGridSize,
kBlockSize,
0,
static_cast<float*>(d_x_mem.GetDeviceBuffer()),
static_cast<float*>(d_phi_mem.GetDeviceBuffer()),
static_cast<float*>(d_output_mem.GetDeviceBuffer()),
r,
alpha_pre,
alpha_post,
alpha_res,
bias));
d_output_mem.FromDevice(h_output.data());
// Compute reference with the same activation function
ck_tile::HostTensor<float> h_output_ref({B, output_dim});
h_output_ref.SetZero();
ck_tile::reference_mhc<float, float, float, float, ActivationFunc>(h_x,
h_phi,
h_output_ref,
n,
C,
r,
alpha_pre,
alpha_post,
alpha_res,
bias,
ActivationFunc{});
// Validate
bool pass = ck_tile::check_err(
h_output, h_output_ref, "Error: Activation function mismatch!", 1e-3f, 1e-3f);
std::cout << " Result: " << (pass ? "PASS" : "FAIL") << std::endl;
if(!pass)
{
// Print first few values for debugging
std::cout << " First batch kernel output: [";
for(int i = 0; i < std::min(8, output_dim); i++)
{
std::cout << h_output(0, i);
if(i < std::min(8, output_dim) - 1)
std::cout << ", ";
}
std::cout << " ...]" << std::endl;
std::cout << " First batch reference: [";
for(int i = 0; i < std::min(8, output_dim); i++)
{
std::cout << h_output_ref(0, i);
if(i < std::min(8, output_dim) - 1)
std::cout << ", ";
}
std::cout << " ...]" << std::endl;
}
// Propagate the verdict to the process exit code so CI / ctest can detect a
// failure; the sibling *_v3 tests already do this, but this test previously
// always exited 0 even on FAIL.
return pass ? 0 : 1;
}

View File

@@ -0,0 +1,148 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <vector>
#include <cmath>
#include <tuple>
#include <iostream>
#include <cstring>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/mhc.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/reference/reference_mhc.hpp"
#include "ck_tile/host/check_err.hpp"
// End-to-end test for MHC kernel V3 (2D-tiled batch x output GEMM): runs the
// device kernel and the host reference on random data and compares; exits
// non-zero on mismatch.
int main()
{
const int B = 1024; // Batch size
const int n = 4; // Expansion rate (aka streams)
const int C = 4096; // Output layer dim
const int nC = n * C; // Total input dimension
const int output_dim = 2 * n + n * n; // 2n + n^2 = 24
using ActivationFunc = ck_tile::element_wise::Sigmoid;
std::cout << "\n--- Testing MHC Kernel V3 with B=" << B << " (n=" << n << ", C=" << C << ") ---"
<< std::endl;
std::cout << "Output dimension: " << output_dim << std::endl;
// Allocate host tensors
ck_tile::HostTensor<float> h_x({B, nC});
ck_tile::HostTensor<float> h_phi({nC, output_dim});
ck_tile::HostTensor<float> h_output({B, output_dim});
// Initialize with random data
ck_tile::FillUniformDistribution<float>{-1.0f, 1.0f}(h_x);
ck_tile::FillUniformDistribution<float>{-0.5f, 0.5f}(h_phi);
h_output.SetZero();
// Allocate device memory
ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes());
// Copy data to device
d_x_mem.ToDevice(h_x.data());
d_phi_mem.ToDevice(h_phi.data());
d_output_mem.ToDevice(h_output.data());
// Define block shape
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 256>,
ck_tile::sequence<1, 256>,
ck_tile::sequence<1, 1>>;
using Problem = ck_tile::MHCProblem<float, float, float, BlockShape>;
// V3 kernel with 2D tiling
constexpr ck_tile::index_t kMTile = 64; // Batch tile
constexpr ck_tile::index_t kNTile = 32; // Output tile (exactly covers 24 outputs for n=4)
constexpr ck_tile::index_t kKTile = 8; // K tile for C dimension (must match BlockGemmShape::kK)
using KernelV3 = ck_tile::
MHCKernelV3<Problem, ck_tile::MHCDefaultPolicy, kMTile, kNTile, kKTile, ActivationFunc>;
const ck_tile::index_t kBlockSize = KernelV3::BlockSize();
// 2D grid: (batch / kMTile) × (output_dim / kNTile)
auto grid_size = KernelV3::GetGridSize(B, output_dim);
// The kernel decodes the flat 1D block id back into (m, n) tile coordinates,
// so the two grid dimensions are multiplied into a single launch dimension.
const ck_tile::index_t kGridSize =
grid_size.at(ck_tile::number<0>{}) * grid_size.at(ck_tile::number<1>{});
std::cout << "Grid configuration: " << grid_size.at(ck_tile::number<0>{}) << " × "
<< grid_size.at(ck_tile::number<1>{}) << " = " << kGridSize << " blocks" << std::endl;
std::cout << "Block size: " << kBlockSize << " threads" << std::endl;
std::cout << "Shared memory: " << KernelV3::GetSmemSize() << " bytes" << std::endl;
constexpr ck_tile::index_t kBlockPerCu = 1;
const float r = 2.0f, alpha_pre = 1.5f, alpha_post = 2.5f, alpha_res = 3.5f, bias = 1.5f;
// Launch kernel
// (dynamic shared memory is sized by the kernel's own GetSmemSize estimate)
ck_tile::launch_kernel(
ck_tile::stream_config{nullptr, false, 0},
ck_tile::make_kernel<kBlockPerCu>(KernelV3{},
kGridSize,
kBlockSize,
KernelV3::GetSmemSize(),
static_cast<float*>(d_x_mem.GetDeviceBuffer()),
static_cast<float*>(d_phi_mem.GetDeviceBuffer()),
static_cast<float*>(d_output_mem.GetDeviceBuffer()),
B,
nC,
output_dim,
n,
r,
alpha_pre,
alpha_post,
alpha_res,
bias));
d_output_mem.FromDevice(h_output.data());
// Compute reference
ck_tile::HostTensor<float> h_output_ref({B, output_dim});
h_output_ref.SetZero();
ck_tile::reference_mhc<float, float, float, float, ActivationFunc>(h_x,
h_phi,
h_output_ref,
n,
C,
r,
alpha_pre,
alpha_post,
alpha_res,
bias,
ActivationFunc{});
// Validate
bool pass = ck_tile::check_err(
h_output, h_output_ref, "Error: MHC V3 kernel output mismatch!", 1e-3f, 1e-3f);
std::cout << "Result: " << (pass ? "PASS" : "FAIL") << std::endl;
if(!pass)
{
// Print first few values for debugging
// NOTE(review): std::min needs <algorithm>; this currently relies on a
// transitive include from the ck_tile headers.
std::cout << "First batch kernel output: [";
for(int i = 0; i < std::min(8, output_dim); i++)
{
std::cout << h_output(0, i);
if(i < std::min(8, output_dim) - 1)
std::cout << ", ";
}
std::cout << " ...]" << std::endl;
std::cout << "First batch reference: [";
for(int i = 0; i < std::min(8, output_dim); i++)
{
std::cout << h_output_ref(0, i);
if(i < std::min(8, output_dim) - 1)
std::cout << ", ";
}
std::cout << " ...]" << std::endl;
}
return pass ? 0 : 1;
}

View File

@@ -0,0 +1,150 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <vector>
#include <cmath>
#include <tuple>
#include <iostream>
#include <cstring>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/mhc.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/reference/reference_mhc.hpp"
#include "ck_tile/host/check_err.hpp"
// Single-block smoke test for MHC kernel V3: B=64 with kMTile=64 and
// kNTile=32 >= output_dim=24 means the whole problem fits in exactly one
// block, isolating intra-block correctness from grid decomposition.
int main()
{
// TESTING: Use smaller batch to fit in single block
const int B = 64; // Single block worth of batch
const int n = 4; // Expansion rate (aka streams)
const int C = 4096; // Output layer dim
const int nC = n * C; // Total input dimension
const int output_dim = 2 * n + n * n; // 2n + n^2 = 24
using ActivationFunc = ck_tile::element_wise::Sigmoid;
std::cout << "\n--- Testing MHC Kernel V3 SINGLE BLOCK with B=" << B << " (n=" << n
<< ", C=" << C << ") ---" << std::endl;
std::cout << "Output dimension: " << output_dim << std::endl;
// Allocate host tensors
ck_tile::HostTensor<float> h_x({B, nC});
ck_tile::HostTensor<float> h_phi({nC, output_dim});
ck_tile::HostTensor<float> h_output({B, output_dim});
// Initialize with random data
ck_tile::FillUniformDistribution<float>{-1.0f, 1.0f}(h_x);
ck_tile::FillUniformDistribution<float>{-0.5f, 0.5f}(h_phi);
h_output.SetZero();
// Allocate device memory
ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes());
// Copy data to device
d_x_mem.ToDevice(h_x.data());
d_phi_mem.ToDevice(h_phi.data());
d_output_mem.ToDevice(h_output.data());
// Define block shape
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 256>,
ck_tile::sequence<1, 256>,
ck_tile::sequence<1, 1>>;
using Problem = ck_tile::MHCProblem<float, float, float, BlockShape>;
// V3 kernel with 2D tiling
constexpr ck_tile::index_t kMTile = 64; // Batch tile
constexpr ck_tile::index_t kNTile = 32; // Output tile
constexpr ck_tile::index_t kKTile = 8; // K tile
using KernelV3 = ck_tile::
MHCKernelV3<Problem, ck_tile::MHCDefaultPolicy, kMTile, kNTile, kKTile, ActivationFunc>;
const ck_tile::index_t kBlockSize = KernelV3::BlockSize();
// Should be exactly 1 block
auto grid_size = KernelV3::GetGridSize(B, output_dim);
const ck_tile::index_t kGridSize =
grid_size.at(ck_tile::number<0>{}) * grid_size.at(ck_tile::number<1>{});
std::cout << "Grid configuration: " << grid_size.at(ck_tile::number<0>{}) << " × "
<< grid_size.at(ck_tile::number<1>{}) << " = " << kGridSize << " blocks (SHOULD BE 1)"
<< std::endl;
std::cout << "Block size: " << kBlockSize << " threads" << std::endl;
std::cout << "Shared memory: " << KernelV3::GetSmemSize() << " bytes" << std::endl;
constexpr ck_tile::index_t kBlockPerCu = 1;
const float r = 2.0f, alpha_pre = 1.5f, alpha_post = 2.5f, alpha_res = 3.5f, bias = 1.5f;
// Launch kernel
ck_tile::launch_kernel(
ck_tile::stream_config{nullptr, false, 0},
ck_tile::make_kernel<kBlockPerCu>(KernelV3{},
kGridSize,
kBlockSize,
KernelV3::GetSmemSize(),
static_cast<float*>(d_x_mem.GetDeviceBuffer()),
static_cast<float*>(d_phi_mem.GetDeviceBuffer()),
static_cast<float*>(d_output_mem.GetDeviceBuffer()),
B,
nC,
output_dim,
n,
r,
alpha_pre,
alpha_post,
alpha_res,
bias));
d_output_mem.FromDevice(h_output.data());
// Compute reference
ck_tile::HostTensor<float> h_output_ref({B, output_dim});
h_output_ref.SetZero();
ck_tile::reference_mhc<float, float, float, float, ActivationFunc>(h_x,
h_phi,
h_output_ref,
n,
C,
r,
alpha_pre,
alpha_post,
alpha_res,
bias,
ActivationFunc{});
// Validate
bool pass = ck_tile::check_err(
h_output, h_output_ref, "Error: MHC V3 SINGLE BLOCK kernel output mismatch!", 1e-3f, 1e-3f);
std::cout << "Result: " << (pass ? "PASS" : "FAIL") << std::endl;
if(!pass)
{
// Print first few values for debugging
std::cout << "First batch kernel output: [";
for(int i = 0; i < std::min(8, output_dim); i++)
{
std::cout << h_output(0, i);
if(i < std::min(8, output_dim) - 1)
std::cout << ", ";
}
std::cout << " ...]" << std::endl;
std::cout << "First batch reference: [";
for(int i = 0; i < std::min(8, output_dim); i++)
{
std::cout << h_output_ref(0, i);
if(i < std::min(8, output_dim) - 1)
std::cout << ", ";
}
std::cout << " ...]" << std::endl;
}
return pass ? 0 : 1;
}

View File

@@ -0,0 +1,168 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <vector>
#include <cmath>
#include <tuple>
#include <iostream>
#include <cstring>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/mhc.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/reference/reference_mhc.hpp"
#include "ck_tile/host/check_err.hpp"
// Two-block test for MHC kernel V3: B=128 with kMTile=64 gives exactly two
// blocks along the batch dimension, exercising the block-id -> (m, n) tile
// decomposition and per-block pointer offsets. Prints one row from each
// block's batch range on failure.
int main()
{
// TESTING: Use 2 blocks worth of batch
const int B = 128; // Two blocks worth of batch
const int n = 4; // Expansion rate (aka streams)
const int C = 4096; // Output layer dim
const int nC = n * C; // Total input dimension
const int output_dim = 2 * n + n * n; // 2n + n^2 = 24
using ActivationFunc = ck_tile::element_wise::Sigmoid;
std::cout << "\n--- Testing MHC Kernel V3 TWO BLOCKS with B=" << B << " (n=" << n << ", C=" << C
<< ") ---" << std::endl;
std::cout << "Output dimension: " << output_dim << std::endl;
// Allocate host tensors
ck_tile::HostTensor<float> h_x({B, nC});
ck_tile::HostTensor<float> h_phi({nC, output_dim});
ck_tile::HostTensor<float> h_output({B, output_dim});
// Initialize with random data
ck_tile::FillUniformDistribution<float>{-1.0f, 1.0f}(h_x);
ck_tile::FillUniformDistribution<float>{-0.5f, 0.5f}(h_phi);
h_output.SetZero();
// Allocate device memory
ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes());
// Copy data to device
d_x_mem.ToDevice(h_x.data());
d_phi_mem.ToDevice(h_phi.data());
d_output_mem.ToDevice(h_output.data());
// Define block shape
using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 256>,
ck_tile::sequence<1, 256>,
ck_tile::sequence<1, 1>>;
using Problem = ck_tile::MHCProblem<float, float, float, BlockShape>;
// V3 kernel with 2D tiling
constexpr ck_tile::index_t kMTile = 64; // Batch tile
constexpr ck_tile::index_t kNTile = 32; // Output tile
constexpr ck_tile::index_t kKTile = 8; // K tile
using KernelV3 = ck_tile::
MHCKernelV3<Problem, ck_tile::MHCDefaultPolicy, kMTile, kNTile, kKTile, ActivationFunc>;
const ck_tile::index_t kBlockSize = KernelV3::BlockSize();
// Should be exactly 2 blocks
auto grid_size = KernelV3::GetGridSize(B, output_dim);
const ck_tile::index_t kGridSize =
grid_size.at(ck_tile::number<0>{}) * grid_size.at(ck_tile::number<1>{});
std::cout << "Grid configuration: " << grid_size.at(ck_tile::number<0>{}) << " × "
<< grid_size.at(ck_tile::number<1>{}) << " = " << kGridSize << " blocks (SHOULD BE 2)"
<< std::endl;
std::cout << "Block size: " << kBlockSize << " threads" << std::endl;
std::cout << "Shared memory: " << KernelV3::GetSmemSize() << " bytes" << std::endl;
constexpr ck_tile::index_t kBlockPerCu = 1;
const float r = 2.0f, alpha_pre = 1.5f, alpha_post = 2.5f, alpha_res = 3.5f, bias = 1.5f;
// Launch kernel
ck_tile::launch_kernel(
ck_tile::stream_config{nullptr, false, 0},
ck_tile::make_kernel<kBlockPerCu>(KernelV3{},
kGridSize,
kBlockSize,
KernelV3::GetSmemSize(),
static_cast<float*>(d_x_mem.GetDeviceBuffer()),
static_cast<float*>(d_phi_mem.GetDeviceBuffer()),
static_cast<float*>(d_output_mem.GetDeviceBuffer()),
B,
nC,
output_dim,
n,
r,
alpha_pre,
alpha_post,
alpha_res,
bias));
d_output_mem.FromDevice(h_output.data());
// Compute reference
ck_tile::HostTensor<float> h_output_ref({B, output_dim});
h_output_ref.SetZero();
ck_tile::reference_mhc<float, float, float, float, ActivationFunc>(h_x,
h_phi,
h_output_ref,
n,
C,
r,
alpha_pre,
alpha_post,
alpha_res,
bias,
ActivationFunc{});
// Validate
bool pass = ck_tile::check_err(
h_output, h_output_ref, "Error: MHC V3 TWO BLOCKS kernel output mismatch!", 1e-3f, 1e-3f);
std::cout << "Result: " << (pass ? "PASS" : "FAIL") << std::endl;
if(!pass)
{
// Print first and second batch values for debugging
std::cout << "\nFirst batch (0) kernel output: [";
for(int i = 0; i < std::min(8, output_dim); i++)
{
std::cout << h_output(0, i);
if(i < std::min(8, output_dim) - 1)
std::cout << ", ";
}
std::cout << " ...]" << std::endl;
std::cout << "First batch (0) reference: [";
for(int i = 0; i < std::min(8, output_dim); i++)
{
std::cout << h_output_ref(0, i);
if(i < std::min(8, output_dim) - 1)
std::cout << ", ";
}
std::cout << " ...]" << std::endl;
// Row 64 is the first row handled by the second block (64 = kMTile).
std::cout << "\nSecond block first batch (64) kernel output: [";
for(int i = 0; i < std::min(8, output_dim); i++)
{
std::cout << h_output(64, i);
if(i < std::min(8, output_dim) - 1)
std::cout << ", ";
}
std::cout << " ...]" << std::endl;
std::cout << "Second block first batch (64) reference: [";
for(int i = 0; i < std::min(8, output_dim); i++)
{
std::cout << h_output_ref(64, i);
if(i < std::min(8, output_dim) - 1)
std::cout << ", ";
}
std::cout << " ...]" << std::endl;
}
return pass ? 0 : 1;
}

View File

@@ -30,4 +30,5 @@ add_subdirectory(36_pooling)
add_subdirectory(38_block_scale_gemm)
add_subdirectory(40_streamk_gemm)
add_subdirectory(41_batched_contraction)
add_subdirectory(42_mhc)

View File

@@ -22,12 +22,12 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
HostTensor<YDataType>& output_b_out, // [B, 2n+n^2]
int n, // expansion factor
int C, // channels per stream
[[maybe_unused]] float r = 1.0f,
[[maybe_unused]] float alpha_pre = 1.0f,
[[maybe_unused]] float alpha_post = 1.0f,
[[maybe_unused]] float alpha_res = 1.0f,
[[maybe_unused]] float bias = 0.0f,
Activation activation = Activation{})
[[maybe_unused]] float r = 1.0f,
[[maybe_unused]] float alpha_pre = 1.0f,
[[maybe_unused]] float alpha_post = 1.0f,
[[maybe_unused]] float alpha_res = 1.0f,
[[maybe_unused]] float bias = 0.0f,
[[maybe_unused]] Activation activation = Activation{})
{
const int B = x_b_nc.get_length(0);
const int nC = n * C;
@@ -46,6 +46,7 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
// Step 2 & 3: Perform GEMM and apply elementwise operations
// TESTING: Comment out post-GEMM operations to validate GEMM only
// Process H^{pre}: x * phi[:, 0:n] -> sigma(output[:, 0:n])
for(int out_idx = 0; out_idx < n; out_idx++)
{
@@ -55,11 +56,14 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
sum += type_convert<ComputeDataType>(x_b_nc(b, k)) *
type_convert<ComputeDataType>(phi_nc_out(k, out_idx));
}
// Step 4: Apply activation σ(H^{pre})
ComputeDataType activated_value;
activation(activated_value, sum);
output_b_out(b, out_idx) =
type_convert<YDataType>((alpha_pre / r) * activated_value + bias);
// // Step 4: Apply activation σ(H^{pre})
// ComputeDataType activated_value;
// activation(activated_value, sum);
// output_b_out(b, out_idx) =
// type_convert<YDataType>((alpha_pre / r) * activated_value + bias);
// TESTING: Store raw GEMM output
output_b_out(b, out_idx) = type_convert<YDataType>(sum);
}
// Process H^{post}: x * phi[:, n:2n] -> 2*sigma(output[:, n:2n])
@@ -71,11 +75,14 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
sum += type_convert<ComputeDataType>(x_b_nc(b, k)) *
type_convert<ComputeDataType>(phi_nc_out(k, n + out_idx));
}
// Step 5: Apply 2*σ(H^{post})
ComputeDataType activated_value;
activation(activated_value, sum);
output_b_out(b, n + out_idx) =
type_convert<YDataType>((alpha_post / r) * 2.0f * activated_value + bias);
// // Step 5: Apply 2*σ(H^{post})
// ComputeDataType activated_value;
// activation(activated_value, sum);
// output_b_out(b, n + out_idx) =
// type_convert<YDataType>((alpha_post / r) * 2.0f * activated_value + bias);
// TESTING: Store raw GEMM output
output_b_out(b, n + out_idx) = type_convert<YDataType>(sum);
}
// Process H^{res}: x * phi[:, 2n:2n+n^2] -> output[:, 2n:2n+n^2]
@@ -88,9 +95,12 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
sum += type_convert<ComputeDataType>(x_b_nc(b, k)) *
type_convert<ComputeDataType>(phi_nc_out(k, 2 * n + out_idx));
}
// Apply: 1/r * alpha_res * sum + bias
output_b_out(b, 2 * n + out_idx) =
type_convert<YDataType>((alpha_res / r) * sum + bias);
// // Apply: 1/r * alpha_res * sum + bias
// output_b_out(b, 2 * n + out_idx) =
// type_convert<YDataType>((alpha_res / r) * sum + bias);
// TESTING: Store raw GEMM output
output_b_out(b, 2 * n + out_idx) = type_convert<YDataType>(sum);
}
// Note: norm is computed but not currently used in the output

View File

@@ -5,9 +5,11 @@
#include "ck_tile/ops/mhc/kernel/mhc_kernel.hpp"
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile.hpp"
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v2.hpp"
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"

View File

@@ -0,0 +1,215 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
// Manifold Constrained Hyper Connection Kernel V3:
// =====================================================================
// Optimizations implemented:
// - Step 2.b: 2D tiling parallelization (batch × output_dim)
// - Step 3: No output_dim tiling (all 24 outputs in one block)
// - Step 4: Use CK-tile GEMM pipeline for proper memory handling
namespace ck_tile {
// MHC kernel V3: computes output[B, output_dim] = x[B, nC] * phi[nC, output_dim]
// with a 2D grid of (batch, output) tiles, delegating the inner GEMM to the
// CK-tile AgBgCr v3 pipeline. Post-GEMM activation/scaling is present but
// commented out while the raw GEMM path is being validated ("TESTING" below).
template <typename Problem_,
typename Policy_ = MHCDefaultPolicy,
index_t kMTile_ = 64, // Batch tile size
index_t kNTile_ = 32, // Output dimension tile (covers all 24 outputs for n=4)
index_t kKTile_ = 8, // K-tile for C dimension (must match BlockGemmShape::kK)
typename Activation_ = element_wise::Sigmoid>
struct MHCKernelV3
{
using Activation = ck_tile::remove_cvref_t<Activation_>;
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using PhiDataType = ck_tile::remove_cvref_t<typename Problem::PhiDataType>;
static constexpr index_t kMTile = kMTile_; // Batch tile
static constexpr index_t kNTile = kNTile_; // Output tile
static constexpr index_t kKTile = kKTile_; // K tile for C dimension
static constexpr index_t kBlockSize = Problem::kBlockSize;
// Threads per block, taken from the problem's block shape.
CK_TILE_HOST static constexpr auto BlockSize() { return kBlockSize; }
// Dynamic shared memory bytes to request at launch.
// NOTE(review): this is an estimate (double-buffered A[kM,kK] + B[kK,kN]
// tiles), not the pipeline's own LDS requirement — confirm it is >= what
// GemmPipelineAgBgCrCompV3 actually needs, including any LDS padding.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
// Calculate shared memory size based on BlockGemmShape
// The pipeline needs LDS for A[kM, kK] and B[kK, kN]
constexpr index_t kM = Problem::BlockGemmShape::kM;
constexpr index_t kN = Problem::BlockGemmShape::kN;
constexpr index_t kK = Problem::BlockGemmShape::kK;
// Approximate LDS size (actual calculation is complex, but this is a safe upper bound)
constexpr index_t a_lds_size = kM * kK * sizeof(XDataType) * 2;
constexpr index_t b_lds_size = kN * kK * sizeof(PhiDataType) * 2;
return a_lds_size + b_lds_size;
}
// Grid configuration: 2D grid over (batch, output_dim)
// Returns (grid_m, grid_n) as ceil-divisions by the tile sizes; the caller
// flattens them into a 1D launch and operator() decodes the block id.
CK_TILE_HOST static constexpr auto GetGridSize(index_t batch, index_t output_dim)
{
const index_t grid_m = (batch + kMTile - 1) / kMTile;
const index_t grid_n = (output_dim + kNTile - 1) / kNTile;
return make_tuple(grid_m, grid_n);
}
// Device entry point.
//   p_x:        [batch, nC] row-major input
//   p_phi:      [nC, output_dim] row-major weights
//   p_output:   [batch, output_dim] row-major output
//   n..bias:    post-GEMM scaling parameters, currently unused (TESTING)
CK_TILE_DEVICE void operator()(const XDataType* p_x,
const PhiDataType* p_phi,
YDataType* p_output,
index_t batch,
index_t nC,
index_t output_dim,
[[maybe_unused]] index_t n,
[[maybe_unused]] float r = 1.0f,
[[maybe_unused]] float alpha_pre = 1.0f,
[[maybe_unused]] float alpha_post = 1.0f,
[[maybe_unused]] float alpha_res = 1.0f,
[[maybe_unused]] float bias = 0.0f) const
{
// 2D block indexing
// Recover (block_m, block_n) from the flattened 1D block id. Must match
// the flattening order used on the host (m-major over grid_n).
const index_t grid_n_size = (output_dim + kNTile - 1) / kNTile;
const index_t block_m = get_block_id() / grid_n_size;
const index_t block_n = get_block_id() % grid_n_size;
const index_t batch_start = block_m * kMTile;
const index_t out_start = block_n * kNTile;
// Whole-block early out; taken uniformly by every thread in the block.
if(batch_start >= batch || out_start >= output_dim)
return;
// Create tensor views with adjusted pointers and dimensions
// The GEMM pipeline expects windows with origin {0,0} relative to the tensor view
const index_t remaining_batch = batch - batch_start;
const index_t remaining_output = output_dim - out_start;
auto x_tensor_unpadded = make_naive_tensor_view<address_space_enum::global>(
p_x + batch_start * nC, // Adjust pointer to start at this block's batch range
make_tuple(remaining_batch, nC), // Dimensions from this block's starting point
make_tuple(nC, 1),
number<1>{},
number<1>{});
auto phi_tensor_unpadded = make_naive_tensor_view<address_space_enum::global>(
p_phi + out_start, // Adjust pointer to start at this block's output range
make_tuple(nC, remaining_output), // Dimensions from this block's starting point
make_tuple(remaining_output, 1),
number<1>{},
number<1>{});
// Pad tensors to tile sizes to handle boundary conditions
auto x_tensor = pad_tensor_view(
x_tensor_unpadded, make_tuple(number<kMTile>{}, number<kKTile>{}), sequence<0, 1>{});
auto phi_tensor = pad_tensor_view(
phi_tensor_unpadded, make_tuple(number<kKTile>{}, number<kNTile>{}), sequence<0, 1>{});
// Create DRAM tile windows with origin {0, 0} relative to the padded tensor views
// The pipeline will internally manage K-dimension iteration
auto x_dram_window =
make_tile_window(x_tensor,
make_tuple(number<kMTile>{}, number<kKTile>{}),
{0, 0}); // Origin at {0, 0} relative to the padded tensor view
auto phi_dram_window =
make_tile_window(phi_tensor,
make_tuple(number<kKTile>{}, number<kNTile>{}),
{0, 0}); // Origin at {0, 0} relative to the padded tensor view
// Use GEMM pipeline v3 to compute the full GEMM
using GemmPipeline = GemmPipelineAgBgCrCompV3<Problem>;
const index_t num_k_loops = (nC + kKTile - 1) / kKTile;
extern __shared__ char smem[];
auto gemm_pipeline = GemmPipeline{};
// V3 pipeline expects non-tuple windows and uses identity functions internally
auto result_tile = gemm_pipeline(x_dram_window, phi_dram_window, num_k_loops, smem);
// Apply elementwise operations (currently commented out for GEMM testing)
// NOTE(review): with the body commented out, this sweep reads `value` and
// writes nothing — dead work kept only as scaffolding for re-enabling the
// activation path. The actual store happens via store_tile below.
constexpr auto result_spans = decltype(result_tile)::get_distributed_spans();
sweep_tile_span(result_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(result_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
result_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const index_t local_m = tile_idx.at(number<0>{});
const index_t local_n = tile_idx.at(number<1>{});
const index_t global_m = batch_start + local_m;
const index_t global_n = out_start + local_n;
if(global_m < batch && global_n < output_dim)
{
constexpr auto i_j_idx = make_tuple(idx0, idx1);
[[maybe_unused]] ComputeDataType value = result_tile[i_j_idx];
// TESTING: Comment out post-GEMM operations to validate GEMM only
// // Apply activation based on output section
// if(global_n < n)
// {
// ComputeDataType activated_value;
// Activation{}(activated_value, value);
// value = (alpha_pre / r) * activated_value + bias;
// }
// else if(global_n < 2 * n)
// {
// ComputeDataType activated_value;
// Activation{}(activated_value, value);
// value = (alpha_post / r) * 2.0f * activated_value + bias;
// }
// else
// {
// value = (alpha_res / r) * value + bias;
// }
// p_output[global_m * output_dim + global_n] = type_convert<YDataType>(value);
}
});
});
// Cast result to output data type
auto result_output = cast_tile<YDataType>(result_tile);
// Create output tensor view for efficient store_tile operation
// NOTE(review): output_vector_size assumes 16-byte vector stores are legal
// here; with output_dim = 24 floats per row, row starts are not 16-byte
// aligned for every block — confirm the padded view degrades gracefully.
constexpr index_t output_vector_size = 16 / sizeof(YDataType);
auto output_tensor_view_unpadded = make_naive_tensor_view<address_space_enum::global>(
p_output + batch_start * output_dim +
out_start, // Adjust pointer to this block's output region
make_tuple(remaining_batch,
remaining_output), // Dimensions from this block's starting point
make_tuple(output_dim, 1), // Strides: row-major layout
number<output_vector_size>{}, // Vector size for efficient memory access
number<1>{}); // Alignment
// Pad output tensor view to match the tile size (for boundary handling)
auto output_tensor_view = pad_tensor_view(output_tensor_view_unpadded,
make_tuple(number<kMTile>{}, number<kNTile>{}),
sequence<0, 1>{});
// Create tile window for the output using result_output's distribution
auto output_window = make_tile_window(
output_tensor_view,
make_tuple(number<kMTile>{}, number<kNTile>{}),
{0, 0}, // Origin at {0, 0} relative to the padded view
result_output.get_tile_distribution()); // Use distribution from result_output
// Store the result using the tile window (padding will prevent out-of-bounds writes)
store_tile(output_window, result_output);
}
};
} // namespace ck_tile

View File

@@ -13,6 +13,7 @@ namespace ck_tile {
// This policy provides warp gemm configuration for MHC operations
struct MHCDefaultPolicy
{
// Provide warp gemm configuration for float data types
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()

View File

@@ -4,21 +4,16 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
namespace ck_tile {
// Simple GEMM shape for MHC operations
// This provides the kM, kN, kK members that BlockGemm expects
// GEMM shape for MHC operations
// This provides the kM, kN, kK members and warp configuration
template <index_t M_, index_t N_, index_t K_>
struct MHCGemmShape
{
static constexpr index_t kM = M_;
static constexpr index_t kN = N_;
static constexpr index_t kK = K_;
// For compatibility with BlockGemm
static constexpr index_t NumWarps = 1; // Simple: 1 warp for now
static constexpr index_t kBlockSize = 256; // Block size
};
using MHCGemmShape =
TileGemmShape<sequence<M_, N_, K_>, // BlockTile
sequence<1, 1, 1>, // BlockWarps (1 warp in M, N, K)
sequence<M_, N_, K_>>; // WarpTile (same as block tile for single warp)
} // namespace ck_tile

View File

@@ -26,22 +26,54 @@ struct MHCProblem
using CDataType = ComputeDataType; // Output/accumulator matrix C
// BlockGemmShape with kM, kN, kK members for BlockGemm
// BlockGemm expects windows to match exactly: A[kM, kK], B[kK, kN]
// Our windows: x[16, 256], phi[256, 16]
// Try matching to warp gemm size: kM=16, kN=16, kK=16
// We'll need to iterate over K dimension
using BlockGemmShape = MHCGemmShape<16, 16, 16>;
// Keep original BlockShape for other uses
// using BlockShape is already defined above
// Use supported warp gemm configuration for float32: 32x32x8
// We'll use 2 warps in M and 1 warp in N to get 64x32 block
using BlockGemmShape =
TileGemmShape<sequence<64, 32, 8>, // BlockTile (M, N, K)
sequence<2, 1, 1>, // BlockWarps (2 warps in M, 1 in N, 1 in K)
sequence<32, 32, 8>>; // WarpTile (matches available float32 MFMA)
// Layout types for BlockGemm
using ALayout = ck_tile::tensor_layout::gemm::RowMajor; // x is row-major [1, nC]
using BLayout = ck_tile::tensor_layout::gemm::RowMajor; // phi is row-major [nC, n]
using ALayout = ck_tile::tensor_layout::gemm::RowMajor; // x is row-major [B, nC]
using BLayout = ck_tile::tensor_layout::gemm::RowMajor; // phi is row-major [nC, output_dim]
using CLayout = ck_tile::tensor_layout::gemm::RowMajor; // output is row-major
// For GEMM pipeline compatibility
using AsDataTypeTuple = tuple<ADataType>;
using BsDataTypeTuple = tuple<BDataType>;
using AsLayoutTuple = tuple<ALayout>;
using BsLayoutTuple = tuple<BLayout>;
using AElementWise = identity;
using BElementWise = identity;
static constexpr bool TransposeC = false;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false; // TESTING: Disable N padding
static constexpr bool kPadK = false;
static constexpr bool Preshuffle = false;
static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
static constexpr index_t NumWaveGroups = 1;
static constexpr index_t VectorLoadSize = 16;
static constexpr index_t VectorSizeA = 4;
static constexpr index_t VectorSizeB = 4;
// kBlockSize for BlockGemm compatibility
static constexpr index_t kBlockSize = BlockShape::BlockSize;
// Additional traits required by v3 pipeline
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr bool FixedVectorSize = false;
struct Traits
{
static constexpr bool UsePersistentKernel = false;
};
CK_TILE_HOST static const std::string GetName() { return "MHCProblem"; }
};
} // namespace ck_tile

View File

@@ -0,0 +1,44 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Compile-time 2D tiling descriptor: derives per-warp/per-thread work
// partitioning constants from the block/warp/thread tile sizes.
// NOTE(review): all derivations use integer division and assume the tile
// sizes divide evenly (warp covers an integer number of thread tiles, block
// covers an integer number of warp tiles) — no static_asserts enforce this.
template <typename BlockWarps, // num warps along seq<M, N>
typename BlockTile, // block size, seq<M, N>
typename WarpTile, // warp size, seq<M, N>
typename ThreadTile> // contiguous pixels(vector size) along seq<M, N>
struct MHCShape
{
static constexpr index_t Block_M = BlockTile::at(number<0>{});
static constexpr index_t Block_N = BlockTile::at(number<1>{});
static constexpr index_t Warp_M = WarpTile::at(number<0>{});
static constexpr index_t Warp_N = WarpTile::at(number<1>{});
static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{});
static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
// How many thread tiles each lane must process so one warp covers its
// Warp_M x Warp_N region: (tiles in warp region) / warp size.
static constexpr index_t RepeatInWarp =
Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size();
// Assign the in-warp repeat to whichever dimension has more thread tiles;
// the other dimension repeats once.
static constexpr index_t RepeatInWarp_M =
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? RepeatInWarp : 1;
static constexpr index_t RepeatInWarp_N =
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : RepeatInWarp;
// Lanes along each dimension after accounting for the in-warp repeat.
static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M / RepeatInWarp_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N / RepeatInWarp_N;
// Block-level repeats: how many times the warp grid must be stepped to
// cover the whole Block_M x Block_N tile.
static constexpr index_t Repeat_M = Block_M * RepeatInWarp_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
// Total threads per block = warp size * product of warps per block.
static constexpr index_t BlockSize =
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
};
} // namespace ck_tile