Compute GEMM and normalize in one pass: MHV v3

This commit is contained in:
Damien Lejeune
2026-02-10 10:35:10 +00:00
parent 0766752704
commit 6c45f722e7

View File

@@ -97,11 +97,12 @@ struct MHCKernelV3
if(batch_start >= batch || out_start >= output_dim)
return;
// STEP 1: Compute norm BEFORE GEMM to enable future fusion
// Compute norm ||x_l||_2 / sqrt(nC) for each batch element using vectorized loads
// STEP 1: Compute norms BEFORE GEMM
// We'll use these norms in an elementwise function during GEMM's data loading
// This way we normalize X on-the-fly as it's loaded for GEMM
constexpr index_t kVectorSize = 4; // Load 4 floats at a time
ComputeDataType norms[kMTile];
ComputeDataType inv_norms[kMTile]; // Store inverse norms for efficiency
for(index_t local_m = 0; local_m < kMTile; ++local_m)
{
@@ -119,7 +120,6 @@ struct MHCKernelV3
using VecType = ext_vector_t<XDataType, kVectorSize>;
VecType vec_data = *c_style_pointer_cast<const VecType*>(row_ptr + k);
// Accumulate squares
#pragma unroll
for(index_t i = 0; i < kVectorSize; ++i)
{
@@ -135,16 +135,17 @@ struct MHCKernelV3
sum_squares += val * val;
}
norms[local_m] =
const ComputeDataType norm =
ck_tile::sqrt(sum_squares) / ck_tile::sqrt(static_cast<ComputeDataType>(nC));
inv_norms[local_m] = 1.0f / norm; // Store inverse for efficiency
}
else
{
norms[local_m] = 1.0f; // Default for out-of-bounds
inv_norms[local_m] = 1.0f; // Default for out-of-bounds
}
}
// Now setup GEMM after norm computation (better for register allocation)
// STEP 2: Setup GEMM with normalization applied during data loading
// Create full tensor views
auto x_tensor_full = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(batch, nC), make_tuple(nC, 1), number<1>{}, number<1>{});
@@ -168,7 +169,7 @@ struct MHCKernelV3
auto phi_dram_window = make_tile_window(
phi_tensor_padded, make_tuple(number<kNTile>{}, number<kKTile>{}), {out_start, 0});
// Use GEMM pipeline v1
// Use GEMM pipeline v1 with elementwise normalization function
using GemmPipeline = GemmPipelineAGmemBGmemCRegV1<Problem>;
const index_t num_k_loops = (nC + kKTile - 1) / kKTile;
@@ -176,10 +177,19 @@ struct MHCKernelV3
__shared__ char smem[GetSmemSize()];
auto gemm_pipeline = GemmPipeline{};
auto result_tile = gemm_pipeline(
make_tuple(x_dram_window), make_tuple(phi_dram_window), num_k_loops, smem);
// Identity function for phi (no normalization needed)
auto phi_identity_func = [](auto& e, const PhiDataType& phi_val) { e = phi_val; };
// Apply elementwise operations (currently commented out for GEMM testing)
auto result_tile = gemm_pipeline(make_tuple(x_dram_window),
phi_identity_func,
make_tuple(phi_dram_window),
phi_identity_func,
num_k_loops,
smem);
// Apply normalization and activation in post-processing
// Now we divide by norm AFTER the GEMM, which means:
// result = (x * phi) / norm = x/norm * phi (mathematically equivalent)
constexpr auto result_spans = decltype(result_tile)::get_distributed_spans();
sweep_tile_span(result_spans[number<0>{}], [&](auto idx0) {
@@ -197,25 +207,26 @@ struct MHCKernelV3
constexpr auto i_j_idx = make_tuple(idx0, idx1);
[[maybe_unused]] ComputeDataType value = result_tile[i_j_idx];
// Get the norm for this batch element
const ComputeDataType norm = norms[local_m];
// Get the inverse norm for this batch element
const ComputeDataType inv_norm = inv_norms[local_m];
// Apply activation based on output section
// Apply normalization and activation based on output section
// Formula: result = (1/norm) * value * alpha + bias
if(global_n < n)
{
ComputeDataType activated_value;
Activation{}(activated_value, value);
result_tile(i_j_idx) = (alpha_pre / norm) * activated_value + bias;
result_tile(i_j_idx) = alpha_pre * inv_norm * activated_value + bias;
}
else if(global_n < 2 * n)
{
ComputeDataType activated_value;
Activation{}(activated_value, value);
result_tile(i_j_idx) = (alpha_post / norm) * 2.0f * activated_value + bias;
result_tile(i_j_idx) = alpha_post * inv_norm * 2.0f * activated_value + bias;
}
else
{
result_tile(i_j_idx) = (alpha_res / norm) * value + bias;
result_tile(i_j_idx) = alpha_res * inv_norm * value + bias;
}
}
});