From 0a9752db6cd899c3bf5437a5d2be8789391d7627 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 20 Oct 2025 16:12:13 +0300 Subject: [PATCH] cuda: use better block sizes for rms_norm --- ggml/src/ggml-cuda/norm.cu | 40 +++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 9e4931a3..3cafdf6d 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -219,6 +219,7 @@ static __global__ void fused_rms_norm_f32_nc( const int row = blockIdx.x; const int channel = blockIdx.y; + //const int channel = blockIdx.y * blockDim.y + threadIdx.y; const int sample = blockIdx.z; const int tid = threadIdx.x; @@ -244,6 +245,11 @@ static __global__ void fused_rms_norm_f32_nc( } __syncthreads(); tmp = s_sum[lane_id]; + //if constexpr (block_size == 1024) { + // tmp = s_sum[lane_id]; + //} else { + // tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f; + //} tmp = warp_reduce_sum(tmp); } @@ -278,9 +284,10 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); + constexpr int kBlockSize = 256; if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps); + const dim3 block_dims(kBlockSize, 1, 1); + rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps); } else { const dim3 block_dims(1024, 1, 1); rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps); } @@ -290,6 +297,7 @@ static void rms_norm_f32_nc_cuda( static void rms_norm_f32_nc_cuda( const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { + printf("%s: ncols = %d\n", 
__func__, ncols); const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); @@ -302,10 +310,22 @@ static void fused_rms_norm_f32_nc( static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { + constexpr int kBlockSize = 256; GGML_ASSERT(ncols % WARP_SIZE == 0); - if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps); + if (ncols < kBlockSize) { + switch (ncols) { + case 32: fused_rms_norm_f32< 32><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break; + case 64: fused_rms_norm_f32< 64><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break; + case 96: fused_rms_norm_f32< 96><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break; + case 128: fused_rms_norm_f32<128><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break; + case 160: fused_rms_norm_f32<160><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break; + case 192: fused_rms_norm_f32<192><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break; + default : fused_rms_norm_f32<224><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break; + } + } + else if (ncols < 1024) { + const dim3 block_dims(kBlockSize, 1, 1); + fused_rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps); } else { const dim3 block_dims(1024, 1, 1); fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps); } @@ -319,6 +339,16 @@ static void fused_rms_norm_f32_nc_cuda( if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); fused_rms_norm_f32_nc<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps); + //constexpr int kBlockSize = 256; + + //if (nchannels%4 == 0) { + // const dim3 blocks_num(nrows, nchannels/4, nsamples); + // const dim3 block_dims(kBlockSize, 4, 1); + // fused_rms_norm_f32_nc<kBlockSize><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps); + //} else { + // const dim3 block_dims(kBlockSize, 1, 1); + // fused_rms_norm_f32_nc<kBlockSize><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps); + //} } else { const dim3 block_dims(1024, 1, 1); 
fused_rms_norm_f32_nc<1024><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);