Add last steps: activation functions

This commit is contained in:
Damien Lejeune
2026-01-29 08:30:45 -05:00
parent da895cdd88
commit 6ea40157f1
4 changed files with 190 additions and 26 deletions

View File

@@ -8,6 +8,7 @@
#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
// Manifold Constrained Hyper Connection Kernel (True CK Tile Version):
// =====================================================================
@@ -20,15 +21,17 @@
namespace ck_tile {
template <typename Problem_,
typename Policy_ = MHCDefaultPolicy,
index_t B_ = 16, // Batch size (compile-time)
index_t N_ = 4, // Expansion factor (compile-time)
index_t C_ = 64, // Channels per stream (compile-time)
index_t KTile_ = 256> // K-tile size for shared memory (compile-time)
typename Policy_ = MHCDefaultPolicy,
index_t B_ = 16, // Batch size (compile-time)
index_t N_ = 4, // Expansion factor (compile-time)
index_t C_ = 64, // Channels per stream (compile-time)
index_t KTile_ = 256, // K-tile size for shared memory (compile-time)
typename Activation_ = element_wise::Sigmoid> // Activation function (compile-time)
struct ManifoldConstrainedHyperConnectionTiled
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using Activation = ck_tile::remove_cvref_t<Activation_>;
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
@@ -241,16 +244,32 @@ struct ManifoldConstrainedHyperConnectionTiled
if(global_batch < B && global_col < output_dim)
{
// Determine alpha based on the actual output column
float alpha = (global_col < kN) ? alpha_pre
: (global_col < 2 * kN) ? alpha_post
: alpha_res;
constexpr auto i_j_idx = make_tuple(idx0, idx1);
ComputeDataType value = result_tile[i_j_idx];
// Step 4 & 5: Apply activation functions based on output section
if(global_col < kN)
{
// H^{pre}: Apply sigma(H^{pre})
ComputeDataType activated_value;
Activation{}(activated_value, value);
value = (alpha_pre / r) * activated_value + bias;
}
else if(global_col < 2 * kN)
{
// H^{post}: Apply 2*sigma(H^{post})
ComputeDataType activated_value;
Activation{}(activated_value, value);
value = (alpha_post / r) * 2.0f * activated_value + bias;
}
else
{
// H^{res}: No activation (will be Sinkhorn-Knopp later)
value = (alpha_res / r) * value + bias;
}
// Apply scaling and bias, then store: result = (alpha / r) * result + bias
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const index_t global_idx = global_batch * output_dim + global_col;
p_output[global_idx] =
type_convert<YDataType>((alpha / r) * result_tile[i_j_idx] + bias);
p_output[global_idx] = type_convert<YDataType>(value);
}
});
});