// mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-02)
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once

#include <cmath>
#include <thread>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"

namespace ck_tile {
// Reference implementation for Manifold Constrained Hyper Connection
|
||
template <typename XDataType,
|
||
typename PhiDataType,
|
||
typename YDataType,
|
||
typename ComputeDataType = float,
|
||
typename Activation = element_wise::Sigmoid>
|
||
CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B, nC]
|
||
const HostTensor<PhiDataType>& phi_nc_out, // [nC, 2n+n^2]
|
||
HostTensor<YDataType>& output_b_out, // [B, 2n+n^2]
|
||
int n, // expansion factor
|
||
int C, // channels per stream
|
||
[[maybe_unused]] float r = 1.0f,
|
||
[[maybe_unused]] float alpha_pre = 1.0f,
|
||
[[maybe_unused]] float alpha_post = 1.0f,
|
||
[[maybe_unused]] float alpha_res = 1.0f,
|
||
[[maybe_unused]] float bias = 0.0f,
|
||
[[maybe_unused]] Activation activation = Activation{})
|
||
{
|
||
const int B = x_b_nc.get_length(0);
|
||
const int nC = n * C;
|
||
(void)nC; // May not be used in all code paths
|
||
|
||
// Process each batch element
|
||
auto f_batch = [&](auto b) {
|
||
// Step 1: Compute norm ||x_l||_2 / sqrt(nC)
|
||
ComputeDataType sum_squares = 0.0f;
|
||
for(int i = 0; i < nC; i++)
|
||
{
|
||
ComputeDataType val = type_convert<ComputeDataType>(x_b_nc(b, i));
|
||
sum_squares += val * val;
|
||
}
|
||
ComputeDataType norm = std::sqrt(sum_squares) / std::sqrt(static_cast<ComputeDataType>(nC));
|
||
|
||
// Step 2 & 3: Perform GEMM and apply elementwise operations
|
||
|
||
// TESTING: Comment out post-GEMM operations to validate GEMM only
|
||
// Process H^{pre}: x * phi[:, 0:n] -> sigma(output[:, 0:n])
|
||
for(int out_idx = 0; out_idx < n; out_idx++)
|
||
{
|
||
ComputeDataType sum = 0.0f;
|
||
for(int k = 0; k < nC; k++)
|
||
{
|
||
sum += type_convert<ComputeDataType>(x_b_nc(b, k)) *
|
||
type_convert<ComputeDataType>(phi_nc_out(k, out_idx));
|
||
}
|
||
// // Step 4: Apply activation σ(H^{pre})
|
||
// ComputeDataType activated_value;
|
||
// activation(activated_value, sum);
|
||
// output_b_out(b, out_idx) =
|
||
// type_convert<YDataType>((alpha_pre / r) * activated_value + bias);
|
||
|
||
// TESTING: Store raw GEMM output
|
||
output_b_out(b, out_idx) = type_convert<YDataType>(sum);
|
||
}
|
||
|
||
// Process H^{post}: x * phi[:, n:2n] -> 2*sigma(output[:, n:2n])
|
||
for(int out_idx = 0; out_idx < n; out_idx++)
|
||
{
|
||
ComputeDataType sum = 0.0f;
|
||
for(int k = 0; k < nC; k++)
|
||
{
|
||
sum += type_convert<ComputeDataType>(x_b_nc(b, k)) *
|
||
type_convert<ComputeDataType>(phi_nc_out(k, n + out_idx));
|
||
}
|
||
// // Step 5: Apply 2*σ(H^{post})
|
||
// ComputeDataType activated_value;
|
||
// activation(activated_value, sum);
|
||
// output_b_out(b, n + out_idx) =
|
||
// type_convert<YDataType>((alpha_post / r) * 2.0f * activated_value + bias);
|
||
|
||
// TESTING: Store raw GEMM output
|
||
output_b_out(b, n + out_idx) = type_convert<YDataType>(sum);
|
||
}
|
||
|
||
// Process H^{res}: x * phi[:, 2n:2n+n^2] -> output[:, 2n:2n+n^2]
|
||
int n_squared = n * n;
|
||
for(int out_idx = 0; out_idx < n_squared; out_idx++)
|
||
{
|
||
ComputeDataType sum = 0.0f;
|
||
for(int k = 0; k < nC; k++)
|
||
{
|
||
sum += type_convert<ComputeDataType>(x_b_nc(b, k)) *
|
||
type_convert<ComputeDataType>(phi_nc_out(k, 2 * n + out_idx));
|
||
}
|
||
// // Apply: 1/r * alpha_res * sum + bias
|
||
// output_b_out(b, 2 * n + out_idx) =
|
||
// type_convert<YDataType>((alpha_res / r) * sum + bias);
|
||
|
||
// TESTING: Store raw GEMM output
|
||
output_b_out(b, 2 * n + out_idx) = type_convert<YDataType>(sum);
|
||
}
|
||
|
||
// Note: norm is computed but not currently used in the output
|
||
// It could be used for additional normalization if needed
|
||
(void)norm;
|
||
};
|
||
|
||
make_ParallelTensorFunctor(f_batch, B)(std::thread::hardware_concurrency());
|
||
}

} // namespace ck_tile