Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-12 17:26:00 +00:00)

Commit: Readd naive normalization in mhc v3
@@ -57,13 +57,13 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
                        type_convert<ComputeDataType>(phi_nc_out(k, out_idx));
             }
             // // Step 4: Apply activation σ(H^{pre})
-            // ComputeDataType activated_value;
-            // activation(activated_value, sum);
-            // output_b_out(b, out_idx) =
-            //     type_convert<YDataType>((alpha_pre / r) * activated_value + bias);
+            ComputeDataType activated_value;
+            activation(activated_value, sum);
+            output_b_out(b, out_idx) =
+                type_convert<YDataType>((alpha_pre / norm) * activated_value + bias);
 
             // TESTING: Store raw GEMM output
-            output_b_out(b, out_idx) = type_convert<YDataType>(sum);
+            // output_b_out(b, out_idx) = type_convert<YDataType>(sum);
         }
 
         // Process H^{post}: x * phi[:, n:2n] -> 2*sigma(output[:, n:2n])
@@ -76,13 +76,13 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
                        type_convert<ComputeDataType>(phi_nc_out(k, n + out_idx));
             }
             // // Step 5: Apply 2*σ(H^{post})
-            // ComputeDataType activated_value;
-            // activation(activated_value, sum);
-            // output_b_out(b, n + out_idx) =
-            //     type_convert<YDataType>((alpha_post / r) * 2.0f * activated_value + bias);
+            ComputeDataType activated_value;
+            activation(activated_value, sum);
+            output_b_out(b, n + out_idx) =
+                type_convert<YDataType>((alpha_post / norm) * 2.0f * activated_value + bias);
 
             // TESTING: Store raw GEMM output
-            output_b_out(b, n + out_idx) = type_convert<YDataType>(sum);
+            // output_b_out(b, n + out_idx) = type_convert<YDataType>(sum);
         }
 
         // Process H^{res}: x * phi[:, 2n:2n+n^2] -> output[:, 2n:2n+n^2]
@@ -95,17 +95,17 @@ CK_TILE_HOST void reference_mhc(const HostTensor<XDataType>& x_b_nc, // [B
                 sum += type_convert<ComputeDataType>(x_b_nc(b, k)) *
                        type_convert<ComputeDataType>(phi_nc_out(k, 2 * n + out_idx));
             }
-            // // Apply: 1/r * alpha_res * sum + bias
-            // output_b_out(b, 2 * n + out_idx) =
-            //     type_convert<YDataType>((alpha_res / r) * sum + bias);
+            // Apply: 1/r * alpha_res * sum + bias
+            output_b_out(b, 2 * n + out_idx) =
+                type_convert<YDataType>((alpha_res / norm) * sum + bias);
 
             // TESTING: Store raw GEMM output
-            output_b_out(b, 2 * n + out_idx) = type_convert<YDataType>(sum);
+            // output_b_out(b, 2 * n + out_idx) = type_convert<YDataType>(sum);
         }
 
-        // Note: norm is computed but not currently used in the output
-        // It could be used for additional normalization if needed
-        (void)norm;
+        // (void)norm;
     };
 
     make_ParallelTensorFunctor(f_batch, B)(std::thread::hardware_concurrency());
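Written out, the normalization these three hunks re-enable is the following (s_{b,j} denotes the GEMM accumulator `sum` for batch row b and output column j; this shorthand is ours, following the σ and H^{pre}/H^{post}/H^{res} names used in the comments):

    r_b            = ||x_b||_2 / sqrt(nC)
    H^{pre}_{b,j}  = (alpha_pre  / r_b) * σ(s_{b,j})     + bias,   j in [0, n)
    H^{post}_{b,j} = (alpha_post / r_b) * 2 * σ(s_{b,j}) + bias,   j in [n, 2n)
    H^{res}_{b,j}  = (alpha_res  / r_b) * s_{b,j}        + bias,   j in [2n, 2n + n^2)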
@@ -145,6 +145,53 @@ struct MHCKernelV3
         auto result_tile = gemm_pipeline(
             make_tuple(x_dram_window), make_tuple(phi_dram_window), num_k_loops, smem);
 
+        // Compute norm ||x_l||_2 / sqrt(nC) for each batch element using vectorized loads
+        // Use vector loads (float4) for better memory bandwidth utilization
+        constexpr index_t kVectorSize = 4; // Load 4 floats at a time
+
+        ComputeDataType norms[kMTile];
+
+        for(index_t local_m = 0; local_m < kMTile; ++local_m)
+        {
+            const index_t global_m = batch_start + local_m;
+            if(global_m < batch)
+            {
+                ComputeDataType sum_squares = 0.0f;
+                const XDataType* row_ptr = p_x + global_m * nC;
+
+                // Vectorized loop: process kVectorSize elements at a time
+                index_t k = 0;
+                for(; k + kVectorSize <= nC; k += kVectorSize)
+                {
+                    // Load vector of elements
+                    using VecType = ext_vector_t<XDataType, kVectorSize>;
+                    VecType vec_data = *c_style_pointer_cast<const VecType*>(row_ptr + k);
+
+                    // Accumulate squares
+                    #pragma unroll
+                    for(index_t i = 0; i < kVectorSize; ++i)
+                    {
+                        ComputeDataType val = type_convert<ComputeDataType>(vec_data[i]);
+                        sum_squares += val * val;
+                    }
+                }
+
+                // Handle remaining elements (scalar loop)
+                for(; k < nC; ++k)
+                {
+                    ComputeDataType val = type_convert<ComputeDataType>(row_ptr[k]);
+                    sum_squares += val * val;
+                }
+
+                norms[local_m] =
+                    ck_tile::sqrt(sum_squares) / ck_tile::sqrt(static_cast<ComputeDataType>(nC));
+            }
+            else
+            {
+                norms[local_m] = 1.0f; // Default for out-of-bounds
+            }
+        }
+
         // Apply elementwise operations (currently commented out for GEMM testing)
         constexpr auto result_spans = decltype(result_tile)::get_distributed_spans();
 
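The same reduction is easy to sanity-check on the host. Below is a minimal standard-C++ sketch of the pattern, assuming float data; the kernel's ext_vector_t load and c_style_pointer_cast are replaced by a plain unrolled loop, and rms_norm is an illustrative name, not a ck_tile API:

    #include <cmath>
    #include <cstddef>

    // Host-side model of the reduction above: square-accumulate in chunks of 4
    // (mirroring the float4 path), then a scalar tail for the nC % 4 leftovers.
    // "rms_norm" and its signature are illustrative, not part of ck_tile.
    float rms_norm(const float* row, std::size_t nC)
    {
        constexpr std::size_t kVectorSize = 4;
        float sum_squares = 0.0f;

        std::size_t k = 0;
        for(; k + kVectorSize <= nC; k += kVectorSize)
        {
            // The device kernel fetches these 4 values with a single
            // ext_vector_t load; a plain unrolled loop stands in for it here.
            for(std::size_t i = 0; i < kVectorSize; ++i)
            {
                const float val = row[k + i];
                sum_squares += val * val;
            }
        }

        // Scalar tail: covers the last nC % kVectorSize elements.
        for(; k < nC; ++k)
        {
            sum_squares += row[k] * row[k];
        }

        // ||x||_2 / sqrt(nC), i.e. the root-mean-square of the row.
        return std::sqrt(sum_squares) / std::sqrt(static_cast<float>(nC));
    }

One caveat worth noting: the device path's VecType load assumes row_ptr + k is suitably aligned, which holds when p_x is aligned and nC is a multiple of kVectorSize; otherwise rows after the first start misaligned and the vector load would need an alignment check or an unaligned-load fallback.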
@@ -163,26 +210,26 @@ struct MHCKernelV3
                 constexpr auto i_j_idx = make_tuple(idx0, idx1);
                 [[maybe_unused]] ComputeDataType value = result_tile[i_j_idx];
 
-                // TESTING: Comment out post-GEMM operations to validate GEMM only
-                // // Apply activation based on output section
-                // if(global_n < n)
-                // {
-                //     ComputeDataType activated_value;
-                //     Activation{}(activated_value, value);
-                //     value = (alpha_pre / r) * activated_value + bias;
-                // }
-                // else if(global_n < 2 * n)
-                // {
-                //     ComputeDataType activated_value;
-                //     Activation{}(activated_value, value);
-                //     value = (alpha_post / r) * 2.0f * activated_value + bias;
-                // }
-                // else
-                // {
-                //     value = (alpha_res / r) * value + bias;
-                // }
+                // Get the norm for this batch element
+                const ComputeDataType norm = norms[local_m];
 
-                // p_output[global_m * output_dim + global_n] = type_convert<YDataType>(value);
+                // Apply activation based on output section
+                if(global_n < n)
+                {
+                    ComputeDataType activated_value;
+                    Activation{}(activated_value, value);
+                    result_tile(i_j_idx) = (alpha_pre / norm) * activated_value + bias;
+                }
+                else if(global_n < 2 * n)
+                {
+                    ComputeDataType activated_value;
+                    Activation{}(activated_value, value);
+                    result_tile(i_j_idx) = (alpha_post / norm) * 2.0f * activated_value + bias;
+                }
+                else
+                {
+                    result_tile(i_j_idx) = (alpha_res / norm) * value + bias;
+                }
             }
         });
     });
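For validating tiles element by element, the section logic above can be restated as a scalar host function. A sketch under the assumption that Activation is a sigmoid (the σ in the reference comments suggests this; apply_epilogue and sigma are illustrative names, not kernel APIs):

    #include <cmath>

    // Assumed activation: sigmoid, matching the σ notation in the comments.
    float sigma(float x) { return 1.0f / (1.0f + std::exp(-x)); }

    // Maps one GEMM output element to its final value based on which of the
    // three column sections it falls in, mirroring the branch in the kernel.
    float apply_epilogue(float value, int global_n, int n, float norm,
                         float alpha_pre, float alpha_post, float alpha_res, float bias)
    {
        if(global_n < n)          // H^{pre} section: columns [0, n)
            return (alpha_pre / norm) * sigma(value) + bias;
        else if(global_n < 2 * n) // H^{post} section: columns [n, 2n)
            return (alpha_post / norm) * 2.0f * sigma(value) + bias;
        else                      // H^{res} section: columns [2n, 2n + n^2)
            return (alpha_res / norm) * value + bias;
    }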