mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 14:44:09 +00:00
cuda: use better block sizes for rms_norm
This commit is contained in:
@@ -219,6 +219,7 @@ static __global__ void fused_rms_norm_f32_nc(
|
||||
|
||||
const int row = blockIdx.x;
|
||||
const int channel = blockIdx.y;
|
||||
//const int channel = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
const int sample = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
@@ -244,6 +245,11 @@ static __global__ void fused_rms_norm_f32_nc(
|
||||
}
|
||||
__syncthreads();
|
||||
tmp = s_sum[lane_id];
|
||||
//if constexpr (block_size == 1024) {
|
||||
// tmp = s_sum[lane_id];
|
||||
//} else {
|
||||
// tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
|
||||
//}
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
}
|
||||
|
||||
@@ -278,9 +284,10 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
|
||||
|
||||
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
constexpr int kBlockSize = 256;
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
const dim3 block_dims(kBlockSize, 1, 1);
|
||||
rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
@@ -290,6 +297,7 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
|
||||
static void rms_norm_f32_nc_cuda(
|
||||
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
||||
printf("%s: ncols = %d\n", __func__, ncols);
|
||||
const dim3 blocks_num(nrows, nchannels, nsamples);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
@@ -302,10 +310,22 @@ static void rms_norm_f32_nc_cuda(
|
||||
|
||||
static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
|
||||
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
constexpr int kBlockSize = 256;
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
|
||||
if (ncols < kBlockSize) {
|
||||
switch (ncols) {
|
||||
case 32: fused_rms_norm_f32< 32><<<nrows, 32, 0, stream>>>(x, y, dst, ncols, eps); break;
|
||||
case 64: fused_rms_norm_f32< 64><<<nrows, 64, 0, stream>>>(x, y, dst, ncols, eps); break;
|
||||
case 96: fused_rms_norm_f32< 96><<<nrows, 96, 0, stream>>>(x, y, dst, ncols, eps); break;
|
||||
case 128: fused_rms_norm_f32<128><<<nrows, 128, 0, stream>>>(x, y, dst, ncols, eps); break;
|
||||
case 160: fused_rms_norm_f32<160><<<nrows, 160, 0, stream>>>(x, y, dst, ncols, eps); break;
|
||||
case 192: fused_rms_norm_f32<192><<<nrows, 192, 0, stream>>>(x, y, dst, ncols, eps); break;
|
||||
default : fused_rms_norm_f32<224><<<nrows, 224, 0, stream>>>(x, y, dst, ncols, eps); break;
|
||||
}
|
||||
}
|
||||
else if (ncols < 1024) {
|
||||
const dim3 block_dims(kBlockSize, 1, 1);
|
||||
fused_rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
|
||||
@@ -319,6 +339,16 @@ static void fused_rms_norm_f32_nc_cuda(
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
fused_rms_norm_f32_nc<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
//constexpr int kBlockSize = 256;
|
||||
|
||||
//if (nchannels%4 == 0) {
|
||||
// const dim3 blocks_num(nrows, nchannels/4, nsamples);
|
||||
// const dim3 block_dims(kBlockSize, 4, 1);
|
||||
// fused_rms_norm_f32_nc<kBlockSize><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
//} else {
|
||||
// const dim3 block_dims(kBlockSize, 1, 1);
|
||||
// fused_rms_norm_f32_nc<kBlockSize><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
//}
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
fused_rms_norm_f32_nc<1024><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
|
||||
Reference in New Issue
Block a user