WIP: MHC v3

This commit is contained in:
Damien Lejeune
2026-02-05 13:04:18 +00:00
parent 6ea40157f1
commit 43a5678fdf
13 changed files with 957 additions and 41 deletions

View File

@@ -22,12 +22,12 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
HostTensor<YDataType>& output_b_out, // [B, 2n+n^2]
int n, // expansion factor
int C, // channels per stream
[[maybe_unused]] float r = 1.0f,
[[maybe_unused]] float alpha_pre = 1.0f,
[[maybe_unused]] float alpha_post = 1.0f,
[[maybe_unused]] float alpha_res = 1.0f,
[[maybe_unused]] float bias = 0.0f,
Activation activation = Activation{})
[[maybe_unused]] float r = 1.0f,
[[maybe_unused]] float alpha_pre = 1.0f,
[[maybe_unused]] float alpha_post = 1.0f,
[[maybe_unused]] float alpha_res = 1.0f,
[[maybe_unused]] float bias = 0.0f,
[[maybe_unused]] Activation activation = Activation{})
{
const int B = x_b_nc.get_length(0);
const int nC = n * C;
@@ -46,6 +46,7 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
// Step 2 & 3: Perform GEMM and apply elementwise operations
// TESTING: Comment out post-GEMM operations to validate GEMM only
// Process H^{pre}: x * phi[:, 0:n] -> sigma(output[:, 0:n])
for(int out_idx = 0; out_idx < n; out_idx++)
{
@@ -55,11 +56,14 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
sum += type_convert<ComputeDataType>(x_b_nc(b, k)) *
type_convert<ComputeDataType>(phi_nc_out(k, out_idx));
}
// Step 4: Apply activation σ(H^{pre})
ComputeDataType activated_value;
activation(activated_value, sum);
output_b_out(b, out_idx) =
type_convert<YDataType>((alpha_pre / r) * activated_value + bias);
// // Step 4: Apply activation σ(H^{pre})
// ComputeDataType activated_value;
// activation(activated_value, sum);
// output_b_out(b, out_idx) =
// type_convert<YDataType>((alpha_pre / r) * activated_value + bias);
// TESTING: Store raw GEMM output
output_b_out(b, out_idx) = type_convert<YDataType>(sum);
}
// Process H^{post}: x * phi[:, n:2n] -> 2*sigma(output[:, n:2n])
@@ -71,11 +75,14 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
sum += type_convert<ComputeDataType>(x_b_nc(b, k)) *
type_convert<ComputeDataType>(phi_nc_out(k, n + out_idx));
}
// Step 5: Apply 2*σ(H^{post})
ComputeDataType activated_value;
activation(activated_value, sum);
output_b_out(b, n + out_idx) =
type_convert<YDataType>((alpha_post / r) * 2.0f * activated_value + bias);
// // Step 5: Apply 2*σ(H^{post})
// ComputeDataType activated_value;
// activation(activated_value, sum);
// output_b_out(b, n + out_idx) =
// type_convert<YDataType>((alpha_post / r) * 2.0f * activated_value + bias);
// TESTING: Store raw GEMM output
output_b_out(b, n + out_idx) = type_convert<YDataType>(sum);
}
// Process H^{res}: x * phi[:, 2n:2n+n^2] -> output[:, 2n:2n+n^2]
@@ -88,9 +95,12 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
sum += type_convert<ComputeDataType>(x_b_nc(b, k)) *
type_convert<ComputeDataType>(phi_nc_out(k, 2 * n + out_idx));
}
// Apply: 1/r * alpha_res * sum + bias
output_b_out(b, 2 * n + out_idx) =
type_convert<YDataType>((alpha_res / r) * sum + bias);
// // Apply: 1/r * alpha_res * sum + bias
// output_b_out(b, 2 * n + out_idx) =
// type_convert<YDataType>((alpha_res / r) * sum + bias);
// TESTING: Store raw GEMM output
output_b_out(b, 2 * n + out_idx) = type_convert<YDataType>(sum);
}
// Note: norm is computed but not currently used in the output