Files
composable_kernel/include/ck_tile/host/reference/reference_mhc.hpp
2026-02-02 02:55:17 -05:00

105 lines
4.2 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cmath>
#include <thread>
#include <vector>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
namespace ck_tile {
// Reference (CPU) implementation for Manifold Constrained Hyper Connection.
//
// For each batch row x (length nC = n*C) it computes three column-slices of
// the output y = x * phi, post-processed per slice:
//   y[0 : n]         = (alpha_pre  / r) *      sigma(x * phi[:, 0:n])       + bias   (H^{pre})
//   y[n : 2n]        = (alpha_post / r) * 2 *  sigma(x * phi[:, n:2n])      + bias   (H^{post})
//   y[2n : 2n+n^2]   = (alpha_res  / r) *           (x * phi[:, 2n:2n+n^2]) + bias   (H^{res}, no activation)
//
// @param x_b_nc        input activations, shape [B, nC]
// @param phi_nc_out    projection weights, shape [nC, 2n + n^2]
// @param output_b_out  output buffer, shape [B, 2n + n^2] (written in full)
// @param n             expansion factor (number of streams)
// @param C             channels per stream
// @param r, alpha_pre, alpha_post, alpha_res, bias  scalar post-scale/shift factors
// @param activation    elementwise activation sigma (default: sigmoid), called as
//                      activation(out, in)
//
// Batches are processed in parallel via make_ParallelTensorFunctor.
template <typename XDataType,
typename PhiDataType,
typename YDataType,
typename ComputeDataType = float,
typename Activation = element_wise::Sigmoid>
CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B, nC]
                                const HostTensor<PhiDataType>& phi_nc_out, // [nC, 2n+n^2]
                                HostTensor<YDataType>& output_b_out, // [B, 2n+n^2]
                                int n, // expansion factor
                                int C, // channels per stream
                                float r = 1.0f,
                                float alpha_pre = 1.0f,
                                float alpha_post = 1.0f,
                                float alpha_res = 1.0f,
                                float bias = 0.0f,
                                Activation activation = Activation{})
{
    const int B  = x_b_nc.get_length(0);
    const int nC = n * C;

    // Process one batch row; invoked in parallel over b in [0, B).
    auto f_batch = [&](auto b) {
        // Convert the x row to compute precision once, instead of re-converting
        // it inside every dot product below ((2n+n^2) times).
        std::vector<ComputeDataType> x_row(nC);
        for(int i = 0; i < nC; i++)
        {
            x_row[i] = type_convert<ComputeDataType>(x_b_nc(b, i));
        }

        // RMS-style norm ||x||_2 / sqrt(nC). Computed for parity with the
        // device kernel but not currently folded into the output; it could be
        // used for additional normalization if needed.
        ComputeDataType sum_squares = 0.0f;
        for(int i = 0; i < nC; i++)
        {
            sum_squares += x_row[i] * x_row[i];
        }
        ComputeDataType norm =
            std::sqrt(sum_squares) / std::sqrt(static_cast<ComputeDataType>(nC));
        (void)norm;

        // Dot product of the cached x row with column `col` of phi.
        auto dot = [&](int col) {
            ComputeDataType sum = 0.0f;
            for(int k = 0; k < nC; k++)
            {
                sum += x_row[k] * type_convert<ComputeDataType>(phi_nc_out(k, col));
            }
            return sum;
        };

        // H^{pre}: (alpha_pre / r) * sigma(x * phi[:, 0:n]) + bias
        for(int out_idx = 0; out_idx < n; out_idx++)
        {
            ComputeDataType activated_value;
            activation(activated_value, dot(out_idx));
            output_b_out(b, out_idx) =
                type_convert<YDataType>((alpha_pre / r) * activated_value + bias);
        }

        // H^{post}: (alpha_post / r) * 2 * sigma(x * phi[:, n:2n]) + bias
        for(int out_idx = 0; out_idx < n; out_idx++)
        {
            ComputeDataType activated_value;
            activation(activated_value, dot(n + out_idx));
            output_b_out(b, n + out_idx) =
                type_convert<YDataType>((alpha_post / r) * 2.0f * activated_value + bias);
        }

        // H^{res}: (alpha_res / r) * (x * phi[:, 2n:2n+n^2]) + bias, no activation.
        const int n_squared = n * n;
        for(int out_idx = 0; out_idx < n_squared; out_idx++)
        {
            output_b_out(b, 2 * n + out_idx) =
                type_convert<YDataType>((alpha_res / r) * dot(2 * n + out_idx) + bias);
        }
    };
    make_ParallelTensorFunctor(f_batch, B)(std::thread::hardware_concurrency());
}
} // namespace ck_tile