// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
// NOTE(review): the two header names below are missing in this copy of the
// file. This header uses std::sqrt and std::thread, so <cmath> and <thread>
// are the likely originals — confirm against version control.
#include
#include
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
namespace ck_tile {
// Reference implementation for Manifold Constrained Hyper Connection (mHC).
//
// For every batch row b, this computes the dot product of row b of x_b_nc with
// each column of phi_nc_out over the reduction length nC = n * C. The result is
// stored in output_b_out(b, col). Each dot product is accumulated into a
// ComputeDataType sum.
//
// Output columns are laid out as:
//   [0, n)              H^{pre}
//   [n, 2n)             H^{post}
//   [2n, 2n + n*n)      H^{res}
//
// Current state (TESTING): the post-GEMM elementwise steps are commented out:
// the activation, the alpha_* / r scaling and the bias. As a result, the raw
// GEMM value is written for every column. The per-row norm
// ||x||_2 / sqrt(nC) is computed but not used. The parameters r, alpha_pre,
// alpha_post, alpha_res, bias and activation therefore have no effect on the
// output.
//
// Parameters:
//   x_b_nc       - input, read at (b, i) for b in [0, B) and i in [0, nC).
//                  B is taken from x_b_nc.get_length(0).
//   phi_nc_out   - projection weights, read at (k, col) for k in [0, nC) and
//                  col in [0, 2n + n*n).
//   output_b_out - written at (b, col) for b in [0, B) and col in [0, 2n + n*n).
//   n            - expansion factor.
//   C            - channels per stream.
//   r, alpha_pre, alpha_post, alpha_res, bias, activation
//                - currently unused (see the TESTING note above).
//
// NOTE(review): tensor extents are not validated here. The caller presumably
// must size x_b_nc, phi_nc_out and output_b_out to match n and C — confirm.
//
// NOTE(review): this copy is missing the template parameter list. The body
// refers to ComputeDataType and Activation as types, and HostTensor,
// type_convert and static_cast appear without their template arguments.
// Restore these from version control before building.
template
CK_TILE_HOST void reference_mhc(const HostTensor& x_b_nc, // [B, nC]
                                const HostTensor& phi_nc_out, // [nC, 2n+n^2]
                                HostTensor& output_b_out, // [B, 2n+n^2]
                                int n, // expansion factor
                                int C, // channels per stream
                                [[maybe_unused]] float r = 1.0f,
                                [[maybe_unused]] float alpha_pre = 1.0f,
                                [[maybe_unused]] float alpha_post = 1.0f,
                                [[maybe_unused]] float alpha_res = 1.0f,
                                [[maybe_unused]] float bias = 0.0f,
                                [[maybe_unused]] Activation activation = Activation{})
{
    const int B = x_b_nc.get_length(0);
    // Reduction length of every dot product below.
    const int nC = n * C;
    (void)nC; // Redundant: nC is used unconditionally by the loops below.
    // Process each batch element.
    // f_batch handles one batch row b and writes only row b of output_b_out.
    auto f_batch = [&](auto b) {
        // Step 1: Compute norm ||x_l||_2 / sqrt(nC), i.e. the RMS of row b.
        ComputeDataType sum_squares = 0.0f;
        for(int i = 0; i < nC; i++)
        {
            ComputeDataType val = type_convert(x_b_nc(b, i));
            sum_squares += val * val;
        }
        ComputeDataType norm = std::sqrt(sum_squares) / std::sqrt(static_cast(nC));
        // Step 2 & 3: Perform GEMM and apply elementwise operations
        // TESTING: Comment out post-GEMM operations to validate GEMM only
        // NOTE(review): the three loops below differ only in the column offset
        // (0, n, 2n). Each loop re-converts x_b_nc(b, k) for every output
        // column.
        // Process H^{pre}: x * phi[:, 0:n] -> sigma(output[:, 0:n])
        // (currently stores the raw dot product; the activation is disabled).
        for(int out_idx = 0; out_idx < n; out_idx++)
        {
            ComputeDataType sum = 0.0f;
            for(int k = 0; k < nC; k++)
            {
                sum += type_convert(x_b_nc(b, k)) * type_convert(phi_nc_out(k, out_idx));
            }
            // // Step 4: Apply activation σ(H^{pre})
            // ComputeDataType activated_value;
            // activation(activated_value, sum);
            // output_b_out(b, out_idx) =
            //     type_convert((alpha_pre / r) * activated_value + bias);
            // TESTING: Store raw GEMM output
            output_b_out(b, out_idx) = type_convert(sum);
        }
        // Process H^{post}: x * phi[:, n:2n] -> 2*sigma(output[:, n:2n])
        // (currently stores the raw dot product; the activation is disabled).
        for(int out_idx = 0; out_idx < n; out_idx++)
        {
            ComputeDataType sum = 0.0f;
            for(int k = 0; k < nC; k++)
            {
                sum += type_convert(x_b_nc(b, k)) * type_convert(phi_nc_out(k, n + out_idx));
            }
            // // Step 5: Apply 2*σ(H^{post})
            // ComputeDataType activated_value;
            // activation(activated_value, sum);
            // output_b_out(b, n + out_idx) =
            //     type_convert((alpha_post / r) * 2.0f * activated_value + bias);
            // TESTING: Store raw GEMM output
            output_b_out(b, n + out_idx) = type_convert(sum);
        }
        // Process H^{res}: x * phi[:, 2n:2n+n^2] -> output[:, 2n:2n+n^2]
        // (currently stores the raw dot product; the scaling and bias are
        // disabled).
        int n_squared = n * n;
        for(int out_idx = 0; out_idx < n_squared; out_idx++)
        {
            ComputeDataType sum = 0.0f;
            for(int k = 0; k < nC; k++)
            {
                sum += type_convert(x_b_nc(b, k)) * type_convert(phi_nc_out(k, 2 * n + out_idx));
            }
            // // Apply: 1/r * alpha_res * sum + bias
            // output_b_out(b, 2 * n + out_idx) =
            //     type_convert((alpha_res / r) * sum + bias);
            // TESTING: Store raw GEMM output
            output_b_out(b, 2 * n + out_idx) = type_convert(sum);
        }
        // Note: norm is computed but not currently used in the output.
        // It could be used for additional normalization if needed.
        (void)norm;
    };
    // Runs f_batch over the B batch indices. hardware_concurrency() is
    // presumably the worker-thread count. Rows are independent because each
    // f_batch call writes only its own row b.
    // NOTE(review): std::thread::hardware_concurrency() may legally return 0.
    // Confirm that make_ParallelTensorFunctor tolerates a thread count of 0.
    make_ParallelTensorFunctor(f_batch, B)(std::thread::hardware_concurrency());
}
} // namespace ck_tile