cuda: use better block sizes for rms_norm

This commit is contained in:
Iwan Kawrakow
2025-10-20 16:12:13 +03:00
parent 5ae87f6cdf
commit 0a9752db6c

View File

@@ -219,6 +219,7 @@ static __global__ void fused_rms_norm_f32_nc(
const int row = blockIdx.x;
const int channel = blockIdx.y;
//const int channel = blockIdx.y * blockDim.y + threadIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;
@@ -244,6 +245,11 @@ static __global__ void fused_rms_norm_f32_nc(
}
__syncthreads();
tmp = s_sum[lane_id];
//if constexpr (block_size == 1024) {
// tmp = s_sum[lane_id];
//} else {
// tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
//}
tmp = warp_reduce_sum(tmp);
}
@@ -278,9 +284,10 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
constexpr int kBlockSize = 256;
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
const dim3 block_dims(kBlockSize, 1, 1);
rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
@@ -290,6 +297,7 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
static void rms_norm_f32_nc_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
printf("%s: ncols = %d\n", __func__, ncols);
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
@@ -302,10 +310,22 @@ static void rms_norm_f32_nc_cuda(
static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
constexpr int kBlockSize = 256;
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
if (ncols < kBlockSize) {
switch (ncols) {
case 32: fused_rms_norm_f32< 32><<<nrows, 32, 0, stream>>>(x, y, dst, ncols, eps); break;
case 64: fused_rms_norm_f32< 64><<<nrows, 64, 0, stream>>>(x, y, dst, ncols, eps); break;
case 96: fused_rms_norm_f32< 96><<<nrows, 96, 0, stream>>>(x, y, dst, ncols, eps); break;
case 128: fused_rms_norm_f32<128><<<nrows, 128, 0, stream>>>(x, y, dst, ncols, eps); break;
case 160: fused_rms_norm_f32<160><<<nrows, 160, 0, stream>>>(x, y, dst, ncols, eps); break;
case 192: fused_rms_norm_f32<192><<<nrows, 192, 0, stream>>>(x, y, dst, ncols, eps); break;
default : fused_rms_norm_f32<224><<<nrows, 224, 0, stream>>>(x, y, dst, ncols, eps); break;
}
}
else if (ncols < 1024) {
const dim3 block_dims(kBlockSize, 1, 1);
fused_rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
@@ -319,6 +339,16 @@ static void fused_rms_norm_f32_nc_cuda(
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
fused_rms_norm_f32_nc<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
//constexpr int kBlockSize = 256;
//if (nchannels%4 == 0) {
// const dim3 blocks_num(nrows, nchannels/4, nsamples);
// const dim3 block_dims(kBlockSize, 4, 1);
// fused_rms_norm_f32_nc<kBlockSize><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
//} else {
// const dim3 block_dims(kBlockSize, 1, 1);
// fused_rms_norm_f32_nc<kBlockSize><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
//}
} else {
const dim3 block_dims(1024, 1, 1);
fused_rms_norm_f32_nc<1024><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);