Re-add naive normalization in MHC v3

This commit is contained in:
Damien Lejeune
2026-02-06 09:44:20 +00:00
parent 053aed9402
commit e7ebd6c288
4 changed files with 704 additions and 543 deletions

View File

@@ -57,13 +57,13 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
type_convert<ComputeDataType>(phi_nc_out(k, out_idx));
}
// // Step 4: Apply activation σ(H^{pre})
// ComputeDataType activated_value;
// activation(activated_value, sum);
// output_b_out(b, out_idx) =
// type_convert<YDataType>((alpha_pre / r) * activated_value + bias);
ComputeDataType activated_value;
activation(activated_value, sum);
output_b_out(b, out_idx) =
type_convert<YDataType>((alpha_pre / norm) * activated_value + bias);
// TESTING: Store raw GEMM output
output_b_out(b, out_idx) = type_convert<YDataType>(sum);
// output_b_out(b, out_idx) = type_convert<YDataType>(sum);
}
// Process H^{post}: x * phi[:, n:2n] -> 2*sigma(output[:, n:2n])
@@ -76,13 +76,13 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
type_convert<ComputeDataType>(phi_nc_out(k, n + out_idx));
}
// // Step 5: Apply 2*σ(H^{post})
// ComputeDataType activated_value;
// activation(activated_value, sum);
// output_b_out(b, n + out_idx) =
// type_convert<YDataType>((alpha_post / r) * 2.0f * activated_value + bias);
ComputeDataType activated_value;
activation(activated_value, sum);
output_b_out(b, n + out_idx) =
type_convert<YDataType>((alpha_post / norm) * 2.0f * activated_value + bias);
// TESTING: Store raw GEMM output
output_b_out(b, n + out_idx) = type_convert<YDataType>(sum);
// output_b_out(b, n + out_idx) = type_convert<YDataType>(sum);
}
// Process H^{res}: x * phi[:, 2n:2n+n^2] -> output[:, 2n:2n+n^2]
@@ -95,17 +95,17 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
sum += type_convert<ComputeDataType>(x_b_nc(b, k)) *
type_convert<ComputeDataType>(phi_nc_out(k, 2 * n + out_idx));
}
// // Apply: 1/r * alpha_res * sum + bias
// output_b_out(b, 2 * n + out_idx) =
// type_convert<YDataType>((alpha_res / r) * sum + bias);
// Apply: 1/r * alpha_res * sum + bias
output_b_out(b, 2 * n + out_idx) =
type_convert<YDataType>((alpha_res / norm) * sum + bias);
// TESTING: Store raw GEMM output
output_b_out(b, 2 * n + out_idx) = type_convert<YDataType>(sum);
// output_b_out(b, 2 * n + out_idx) = type_convert<YDataType>(sum);
}
// Note: norm is now applied to all three output sections
// (each section is scaled by alpha_* / norm before adding bias)
(void)norm;
// (void)norm;
};
make_ParallelTensorFunctor(f_batch, B)(std::thread::hardware_concurrency());

View File

@@ -145,6 +145,53 @@ struct MHCKernelV3
auto result_tile = gemm_pipeline(
make_tuple(x_dram_window), make_tuple(phi_dram_window), num_k_loops, smem);
// Compute norm ||x_l||_2 / sqrt(nC) for each batch element using vectorized loads
// Use vector loads (float4) for better memory bandwidth utilization
constexpr index_t kVectorSize = 4; // Load 4 floats at a time
ComputeDataType norms[kMTile];
for(index_t local_m = 0; local_m < kMTile; ++local_m)
{
const index_t global_m = batch_start + local_m;
if(global_m < batch)
{
ComputeDataType sum_squares = 0.0f;
const XDataType* row_ptr = p_x + global_m * nC;
// Vectorized loop: process kVectorSize elements at a time
index_t k = 0;
for(; k + kVectorSize <= nC; k += kVectorSize)
{
// Load vector of elements
using VecType = ext_vector_t<XDataType, kVectorSize>;
VecType vec_data = *c_style_pointer_cast<const VecType*>(row_ptr + k);
// Accumulate squares
#pragma unroll
for(index_t i = 0; i < kVectorSize; ++i)
{
ComputeDataType val = type_convert<ComputeDataType>(vec_data[i]);
sum_squares += val * val;
}
}
// Handle remaining elements (scalar loop)
for(; k < nC; ++k)
{
ComputeDataType val = type_convert<ComputeDataType>(row_ptr[k]);
sum_squares += val * val;
}
norms[local_m] =
ck_tile::sqrt(sum_squares) / ck_tile::sqrt(static_cast<ComputeDataType>(nC));
}
else
{
norms[local_m] = 1.0f; // Default for out-of-bounds
}
}
// Apply elementwise operations (normalization + activation, re-enabled)
constexpr auto result_spans = decltype(result_tile)::get_distributed_spans();
@@ -163,26 +210,26 @@ struct MHCKernelV3
constexpr auto i_j_idx = make_tuple(idx0, idx1);
[[maybe_unused]] ComputeDataType value = result_tile[i_j_idx];
// TESTING: Comment out post-GEMM operations to validate GEMM only
// // Apply activation based on output section
// if(global_n < n)
// {
// ComputeDataType activated_value;
// Activation{}(activated_value, value);
// value = (alpha_pre / r) * activated_value + bias;
// }
// else if(global_n < 2 * n)
// {
// ComputeDataType activated_value;
// Activation{}(activated_value, value);
// value = (alpha_post / r) * 2.0f * activated_value + bias;
// }
// else
// {
// value = (alpha_res / r) * value + bias;
// }
// Get the norm for this batch element
const ComputeDataType norm = norms[local_m];
// p_output[global_m * output_dim + global_n] = type_convert<YDataType>(value);
// Apply activation based on output section
if(global_n < n)
{
ComputeDataType activated_value;
Activation{}(activated_value, value);
result_tile(i_j_idx) = (alpha_pre / norm) * activated_value + bias;
}
else if(global_n < 2 * n)
{
ComputeDataType activated_value;
Activation{}(activated_value, value);
result_tile(i_j_idx) = (alpha_post / norm) * 2.0f * activated_value + bias;
}
else
{
result_tile(i_j_idx) = (alpha_res / norm) * value + bias;
}
}
});
});