Fix layernorm kernels for wave64 GPUs

This commit is contained in:
turboderp
2024-06-24 02:09:46 +02:00
parent 05b1f2194e
commit 6feebfb56e
2 changed files with 15 additions and 3 deletions

View File

@@ -81,7 +81,11 @@ __global__ void layer_norm_kernel
// Load partial sums from across warps, shuffle again across lanes
- sum = sums[lane_id];
+ #if defined(USE_ROCM)
+ sum = lane_id < NUM_WARPS ? sums[lane_id] : 0.0f;
+ #else
+ sum = sums[lane_id];
+ #endif
for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset);
// Compute mean
@@ -116,7 +120,11 @@ __global__ void layer_norm_kernel
// Load partial sums from across warps, shuffle again across lanes
- sum = sums[lane_id];
+ #if defined(USE_ROCM)
+ sum = lane_id < NUM_WARPS ? sums[lane_id] : 0.0f;
+ #else
+ sum = sums[lane_id];
+ #endif
for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset);
// Get 1/sqrt(variance)

View File

@@ -77,7 +77,11 @@ __global__ void rms_norm_kernel
// Load partial sums from across warps, shuffle again across lanes
- sum = sums[lane_id];
+ #if defined(USE_ROCM)
+ sum = lane_id < NUM_WARPS ? sums[lane_id] : 0.0f;
+ #else
+ sum = sums[lane_id];
+ #endif
for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset);
// Get norm