mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
Fix layernorm kernels for wave64 GPUs
This commit is contained in:
@@ -81,7 +81,11 @@ __global__ void layer_norm_kernel
|
||||
|
||||
// Load partial sums from across warps, shuffle again across lanes
|
||||
|
||||
- sum = sums[lane_id];
+ #if defined(USE_ROCM)
+ sum = lane_id < NUM_WARPS ? sums[lane_id] : 0.0f;
+ #else
+ sum = sums[lane_id];
+ #endif
|
||||
for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset);
|
||||
|
||||
// Compute mean
|
||||
@@ -116,7 +120,11 @@ __global__ void layer_norm_kernel
|
||||
|
||||
// Load partial sums from across warps, shuffle again across lanes
|
||||
|
||||
- sum = sums[lane_id];
+ #if defined(USE_ROCM)
+ sum = lane_id < NUM_WARPS ? sums[lane_id] : 0.0f;
+ #else
+ sum = sums[lane_id];
+ #endif
|
||||
for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset);
|
||||
|
||||
// Get 1/sqrt(variance)
|
||||
|
||||
@@ -77,7 +77,11 @@ __global__ void rms_norm_kernel
|
||||
|
||||
// Load partial sums from across warps, shuffle again across lanes
|
||||
|
||||
- sum = sums[lane_id];
+ #if defined(USE_ROCM)
+ sum = lane_id < NUM_WARPS ? sums[lane_id] : 0.0f;
+ #else
+ sum = sums[lane_id];
+ #endif
|
||||
for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset);
|
||||
|
||||
// Get norm
|
||||
|
||||
Reference in New Issue
Block a user