From f5571e241e6e2624db19259608e099dca9f522a5 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 21 Oct 2025 08:12:48 +0300 Subject: [PATCH] cuda: use better block sizes for rms_norm (#845) * cuda: use better block sizes for rms_norm * Minor * Remove forgotten printf --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda/norm.cu | 43 +++++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 9e4931a3..6c3e565b 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -119,7 +119,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol s_sum[warp_id] = tmp; } __syncthreads(); - tmp = s_sum[lane_id]; + tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f; tmp = warp_reduce_sum(tmp); } @@ -198,7 +198,7 @@ static __global__ void fused_rms_norm_f32(const float * x, const float * y, floa s_sum[warp_id] = tmp; } __syncthreads(); - tmp = s_sum[lane_id]; + tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f; tmp = warp_reduce_sum(tmp); } @@ -219,6 +219,7 @@ static __global__ void fused_rms_norm_f32_nc( const int row = blockIdx.x; const int channel = blockIdx.y; + //const int channel = blockIdx.y * blockDim.y + threadIdx.y; const int sample = blockIdx.z; const int tid = threadIdx.x; @@ -244,6 +245,11 @@ static __global__ void fused_rms_norm_f32_nc( } __syncthreads(); tmp = s_sum[lane_id]; + //if constexpr (block_size == 1024) { + // tmp = s_sum[lane_id]; + //} else { + // tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f; + //} tmp = warp_reduce_sum(tmp); } @@ -278,9 +284,10 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); + constexpr int kBlockSize = 256; if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, eps); + const dim3 block_dims(kBlockSize, 1, 1); + rms_norm_f32<<>>(x, dst, ncols, eps); } else { const dim3 block_dims(1024, 1, 1); rms_norm_f32<1024><<>>(x, dst, ncols, eps); @@ -302,10 +309,22 @@ static void rms_norm_f32_nc_cuda( static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { + constexpr int kBlockSize = 256; GGML_ASSERT(ncols % WARP_SIZE == 0); - if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - fused_rms_norm_f32<<>>(x, y, dst, ncols, eps); + if (ncols < kBlockSize) { + switch (ncols) { + case 32: fused_rms_norm_f32< 32><<>>(x, y, dst, ncols, eps); break; + case 64: fused_rms_norm_f32< 64><<>>(x, y, dst, ncols, eps); break; + case 96: fused_rms_norm_f32< 96><<>>(x, y, dst, ncols, eps); break; + case 128: fused_rms_norm_f32<128><<>>(x, y, dst, ncols, eps); break; + case 160: fused_rms_norm_f32<160><<>>(x, y, dst, ncols, eps); break; + case 192: fused_rms_norm_f32<192><<>>(x, y, dst, ncols, eps); break; + default : fused_rms_norm_f32<224><<>>(x, y, dst, ncols, eps); break; + } + } + else if (ncols < 1024) { + const dim3 block_dims(kBlockSize, 1, 1); + fused_rms_norm_f32<<>>(x, y, dst, ncols, eps); } else { const dim3 block_dims(1024, 1, 1); fused_rms_norm_f32<1024><<>>(x, y, dst, ncols, eps); @@ -319,6 +338,16 @@ static void fused_rms_norm_f32_nc_cuda( if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); fused_rms_norm_f32_nc<<>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps); + //constexpr int kBlockSize = 256; + + //if (nchannels%4 == 0) { + // const dim3 blocks_num(nrows, nchannels/4, nsamples); + // const dim3 block_dims(kBlockSize, 4, 1); + // fused_rms_norm_f32_nc<<>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps); + //} else { + // const dim3 block_dims(kBlockSize, 1, 1); + // fused_rms_norm_f32_nc<<>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps); + //} } else { const dim3 block_dims(1024, 1, 1); fused_rms_norm_f32_nc<1024><<>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);