mirror of https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
Add last steps: activation functions
@@ -7,6 +7,7 @@
 #include <thread>
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/host_tensor.hpp"
+#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
 
 namespace ck_tile {
 
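The newly included header supplies the default activation. Judging from how the functor is invoked later in this diff, activation(activated_value, sum), it writes its result through the first argument. A tiny host-side sketch of that convention (assuming element_wise::Sigmoid keeps this two-argument operator() form):

    #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"

    int main()
    {
        float y = 0.0f;
        ck_tile::element_wise::Sigmoid{}(y, 0.0f); // y becomes sigmoid(0) = 0.5
        return (y == 0.5f) ? 0 : 1;
    }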
@@ -14,17 +15,19 @@ namespace ck_tile {
 template <typename XDataType,
           typename PhiDataType,
           typename YDataType,
-          typename ComputeDataType = float>
+          typename ComputeDataType = float,
+          typename Activation = element_wise::Sigmoid>
 CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B, nC]
-                                const HostTensor<PhiDataType>& phi_nc_out, // [nC, 2n+n²]
-                                HostTensor<YDataType>& output_b_out, // [B, 2n+n²]
+                                const HostTensor<PhiDataType>& phi_nc_out, // [nC, 2n+n^2]
+                                HostTensor<YDataType>& output_b_out, // [B, 2n+n^2]
                                 int n, // expansion factor
                                 int C, // channels per stream
                                 [[maybe_unused]] float r = 1.0f,
                                 [[maybe_unused]] float alpha_pre = 1.0f,
                                 [[maybe_unused]] float alpha_post = 1.0f,
                                 [[maybe_unused]] float alpha_res = 1.0f,
-                                [[maybe_unused]] float bias = 0.0f)
+                                [[maybe_unused]] float bias = 0.0f,
+                                Activation activation = Activation{})
 {
     const int B = x_b_nc.get_length(0);
     const int nC = n * C;
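Any functor with this call shape can fill the new Activation slot. A minimal sketch of a custom one (TanhActivation is illustrative and not part of this commit; only the op(y, x) convention is taken from the activation(activated_value, sum) call below):

    #include <cmath>

    // Hypothetical activation functor; writes the result through the first
    // argument, matching how reference_mhc invokes `activation`.
    struct TanhActivation
    {
        template <typename Y, typename X>
        void operator()(Y& y, const X& x) const
        {
            y = static_cast<Y>(std::tanh(static_cast<float>(x)));
        }
    };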
@@ -43,7 +46,7 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
 
     // Step 2 & 3: Perform GEMM and apply elementwise operations
 
-    // Process H^{pre}: x * phi[:, 0:n] -> output[:, 0:n]
+    // Process H^{pre}: x * phi[:, 0:n] -> sigma(output[:, 0:n])
     for(int out_idx = 0; out_idx < n; out_idx++)
     {
         ComputeDataType sum = 0.0f;
@@ -52,11 +55,14 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
             sum += type_convert<ComputeDataType>(x_b_nc(b, k)) *
                    type_convert<ComputeDataType>(phi_nc_out(k, out_idx));
         }
-        // Apply: 1/r * alpha_pre * sum + bias
-        output_b_out(b, out_idx) = type_convert<YDataType>((alpha_pre / r) * sum + bias);
+        // Step 4: Apply activation σ(H^{pre})
+        ComputeDataType activated_value;
+        activation(activated_value, sum);
+        output_b_out(b, out_idx) =
+            type_convert<YDataType>((alpha_pre / r) * activated_value + bias);
     }
 
-    // Process H^{post}: x * phi[:, n:2n] -> output[:, n:2n]
+    // Process H^{post}: x * phi[:, n:2n] -> 2*sigma(output[:, n:2n])
     for(int out_idx = 0; out_idx < n; out_idx++)
     {
         ComputeDataType sum = 0.0f;
@@ -65,11 +71,14 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
             sum += type_convert<ComputeDataType>(x_b_nc(b, k)) *
                    type_convert<ComputeDataType>(phi_nc_out(k, n + out_idx));
         }
-        // Apply: 1/r * alpha_post * sum + bias
-        output_b_out(b, n + out_idx) = type_convert<YDataType>((alpha_post / r) * sum + bias);
+        // Step 5: Apply 2*σ(H^{post})
+        ComputeDataType activated_value;
+        activation(activated_value, sum);
+        output_b_out(b, n + out_idx) =
+            type_convert<YDataType>((alpha_post / r) * 2.0f * activated_value + bias);
     }
 
-    // Process H^{res}: x * phi[:, 2n:2n+n²] -> output[:, 2n:2n+n²]
+    // Process H^{res}: x * phi[:, 2n:2n+n^2] -> output[:, 2n:2n+n^2]
     int n_squared = n * n;
     for(int out_idx = 0; out_idx < n_squared; out_idx++)
     {
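Taken together, the three sections of the host reference now compute, in the diff's own comment notation (the H^{res} scaling is unchanged context here, but the kernel branch below makes it explicit):

    output[:, 0:n]       = (alpha_pre  / r) * sigma(x * phi[:, 0:n])     + bias   // H^{pre},  Step 4
    output[:, n:2n]      = (alpha_post / r) * 2*sigma(x * phi[:, n:2n])  + bias   // H^{post}, Step 5
    output[:, 2n:2n+n^2] = (alpha_res  / r) * (x * phi[:, 2n:2n+n^2])    + bias   // H^{res},  Sinkhorn-Knopp later

The second file applies the same three-way split inside the tiled kernel.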
@@ -8,6 +8,7 @@
 #include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
 #include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
+#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
 
 // Manifold Constrained Hyper Connection Kernel (True CK Tile Version):
 // =====================================================================
@@ -20,15 +21,17 @@
 namespace ck_tile {
 
 template <typename Problem_,
-          typename Policy_ = MHCDefaultPolicy,
-          index_t B_ = 16,      // Batch size (compile-time)
-          index_t N_ = 4,       // Expansion factor (compile-time)
-          index_t C_ = 64,      // Channels per stream (compile-time)
-          index_t KTile_ = 256> // K-tile size for shared memory (compile-time)
+          typename Policy_ = MHCDefaultPolicy,
+          index_t B_ = 16,      // Batch size (compile-time)
+          index_t N_ = 4,       // Expansion factor (compile-time)
+          index_t C_ = 64,      // Channels per stream (compile-time)
+          index_t KTile_ = 256, // K-tile size for shared memory (compile-time)
+          typename Activation_ = element_wise::Sigmoid> // Activation function (compile-time)
 struct ManifoldConstrainedHyperConnectionTiled
 {
-    using Problem = ck_tile::remove_cvref_t<Problem_>;
-    using Policy  = ck_tile::remove_cvref_t<Policy_>;
+    using Problem    = ck_tile::remove_cvref_t<Problem_>;
+    using Policy     = ck_tile::remove_cvref_t<Policy_>;
+    using Activation = ck_tile::remove_cvref_t<Activation_>;
 
     using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
     using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
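A hedged instantiation sketch for the kernel side (MyMHCProblem is a placeholder for a concrete Problem type; the parameter order comes from the template above). Since Activation_ is the last parameter, overriding it means restating the defaults before it:

    using SigmoidKernel = ck_tile::ManifoldConstrainedHyperConnectionTiled<MyMHCProblem>;

    using TanhKernel =
        ck_tile::ManifoldConstrainedHyperConnectionTiled<MyMHCProblem,
                                                         ck_tile::MHCDefaultPolicy,
                                                         16,  // B_: batch size
                                                         4,   // N_: expansion factor
                                                         64,  // C_: channels per stream
                                                         256, // KTile_
                                                         TanhActivation>; // custom Activation_

Note that the kernel invokes Activation{}(activated_value, value) in device code, so a custom functor would also need to be device-callable (e.g., marked CK_TILE_HOST_DEVICE), which the host-only TanhActivation sketch above is not.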
@@ -241,16 +244,32 @@ struct ManifoldConstrainedHyperConnectionTiled
 
         if(global_batch < B && global_col < output_dim)
         {
-            // Determine alpha based on the actual output column
-            float alpha = (global_col < kN)       ? alpha_pre
-                          : (global_col < 2 * kN) ? alpha_post
-                                                  : alpha_res;
+            constexpr auto i_j_idx = make_tuple(idx0, idx1);
+            ComputeDataType value = result_tile[i_j_idx];
+
+            // Step 4 & 5: Apply activation functions based on output section
+            if(global_col < kN)
+            {
+                // H^{pre}: Apply sigma(H^{pre})
+                ComputeDataType activated_value;
+                Activation{}(activated_value, value);
+                value = (alpha_pre / r) * activated_value + bias;
+            }
+            else if(global_col < 2 * kN)
+            {
+                // H^{post}: Apply 2*sigma(H^{post})
+                ComputeDataType activated_value;
+                Activation{}(activated_value, value);
+                value = (alpha_post / r) * 2.0f * activated_value + bias;
+            }
+            else
+            {
+                // H^{res}: No activation (will be Sinkhorn-Knopp later)
+                value = (alpha_res / r) * value + bias;
+            }
 
-            // Apply scaling and bias, then store: result = (alpha / r) * result + bias
-            constexpr auto i_j_idx = make_tuple(idx0, idx1);
             const index_t global_idx = global_batch * output_dim + global_col;
-            p_output[global_idx] =
-                type_convert<YDataType>((alpha / r) * result_tile[i_j_idx] + bias);
+            p_output[global_idx] = type_convert<YDataType>(value);
         }
     });
 });
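Finally, a usage sketch of the updated host reference (the HostTensor shape-list constructor and the call below are assumptions based on typical ck_tile example code, not part of this diff):

    const int B = 16, n = 4, C = 64;
    const int out_dim = 2 * n + n * n;

    ck_tile::HostTensor<float> x({B, n * C});
    ck_tile::HostTensor<float> phi({n * C, out_dim});
    ck_tile::HostTensor<float> out({B, out_dim});

    // Default activation (element_wise::Sigmoid):
    ck_tile::reference_mhc<float, float, float>(x, phi, out, n, C);

    // Custom activation via the new trailing template parameter:
    ck_tile::reference_mhc<float, float, float, float, TanhActivation>(x, phi, out, n, C);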