Fused norm (#1086)

* Adding fused_norm - same idea as fused_rms_norm

* Avoid computing the attention reduce op for cohere2

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-12-24 15:22:43 +01:00
committed by GitHub
parent 5e64235d4c
commit ada5cc1523
7 changed files with 273 additions and 29 deletions

View File

@@ -691,6 +691,7 @@ extern "C" {
GGML_OP_REDUCE,
GGML_OP_FAKE_CPY,
GGML_OP_FUSED_NORM,
GGML_OP_COUNT,
};
@@ -1487,6 +1488,18 @@ extern "C" {
struct ggml_tensor * b,
float eps);
GGML_API struct ggml_tensor * ggml_fused_norm(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float eps);
GGML_API struct ggml_tensor * ggml_fused_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float eps);
// group normalize along ne0*ne1*n_groups
// used in stable-diffusion
GGML_API struct ggml_tensor * ggml_group_norm(

View File

@@ -3208,6 +3208,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_fused_rms_norm(ctx, dst);
}
break;
case GGML_OP_FUSED_NORM:
ggml_cuda_op_fused_rms_norm(ctx, dst, true);
break;
case GGML_OP_MUL_MAT:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
@@ -4166,6 +4169,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ROPE_FAST:
case GGML_OP_ROPE_CACHE:
return true;
case GGML_OP_FUSED_NORM:
return ggml_is_contiguous(op->src[0]);
//case GGML_OP_ROPE:
// return ggml_is_contiguous(op->src[0]);
case GGML_OP_IM2COL:

View File

@@ -36,6 +36,42 @@ static __global__ void norm_f32(const T * x, float * dst, const int ncols, const
}
}
template <int block_size, typename T>
static __global__ void fused_norm_f32(const T * x, const float * c, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
float2 mean_var = make_float2(0.f, 0.f);
for (int col = tid; col < ncols; col += block_size) {
const float xi = (float)x[row*ncols + col];
mean_var.x += xi;
mean_var.y += xi * xi;
}
// sum up partial sums
mean_var = warp_reduce_sum(mean_var);
if (block_size > WARP_SIZE) {
__shared__ float2 s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = mean_var;
}
__syncthreads();
mean_var = s_sum[lane_id];
mean_var = warp_reduce_sum(mean_var);
}
const float mean = mean_var.x / ncols;
const float var = mean_var.y / ncols - mean * mean;
const float inv_std = rsqrtf(var + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = (T)(((float)x[row*ncols + col] - mean) * inv_std * c[col]);
}
}
template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
// blockIdx.x: num_groups idx
@@ -310,26 +346,47 @@ static void rms_norm_f32_nc_cuda(
template <typename src_t>
static void fused_rms_norm_f32_cuda(const src_t * x, const float * y, float * dst,
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
const int ncols, const int nrows, const float eps, bool is_norm, cudaStream_t stream) {
constexpr int kBlockSize = 256;
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < kBlockSize) {
switch (ncols) {
case 32: fused_rms_norm_f32< 32><<<nrows, 32, 0, stream>>>(x, y, dst, ncols, eps); break;
case 64: fused_rms_norm_f32< 64><<<nrows, 64, 0, stream>>>(x, y, dst, ncols, eps); break;
case 96: fused_rms_norm_f32< 96><<<nrows, 96, 0, stream>>>(x, y, dst, ncols, eps); break;
case 128: fused_rms_norm_f32<128><<<nrows, 128, 0, stream>>>(x, y, dst, ncols, eps); break;
case 160: fused_rms_norm_f32<160><<<nrows, 160, 0, stream>>>(x, y, dst, ncols, eps); break;
case 192: fused_rms_norm_f32<192><<<nrows, 192, 0, stream>>>(x, y, dst, ncols, eps); break;
default : fused_rms_norm_f32<224><<<nrows, 224, 0, stream>>>(x, y, dst, ncols, eps); break;
if (is_norm) {
if (ncols < kBlockSize) {
switch (ncols) {
case 32: fused_norm_f32< 32><<<nrows, 32, 0, stream>>>(x, y, dst, ncols, eps); break;
case 64: fused_norm_f32< 64><<<nrows, 64, 0, stream>>>(x, y, dst, ncols, eps); break;
case 96: fused_norm_f32< 96><<<nrows, 96, 0, stream>>>(x, y, dst, ncols, eps); break;
case 128: fused_norm_f32<128><<<nrows, 128, 0, stream>>>(x, y, dst, ncols, eps); break;
case 160: fused_norm_f32<160><<<nrows, 160, 0, stream>>>(x, y, dst, ncols, eps); break;
case 192: fused_norm_f32<192><<<nrows, 192, 0, stream>>>(x, y, dst, ncols, eps); break;
default : fused_norm_f32<224><<<nrows, 224, 0, stream>>>(x, y, dst, ncols, eps); break;
}
}
else if (ncols < 1024) {
const dim3 block_dims(kBlockSize, 1, 1);
fused_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
fused_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
}
}
else if (ncols < 1024) {
const dim3 block_dims(kBlockSize, 1, 1);
fused_rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
if (ncols < kBlockSize) {
switch (ncols) {
case 32: fused_rms_norm_f32< 32><<<nrows, 32, 0, stream>>>(x, y, dst, ncols, eps); break;
case 64: fused_rms_norm_f32< 64><<<nrows, 64, 0, stream>>>(x, y, dst, ncols, eps); break;
case 96: fused_rms_norm_f32< 96><<<nrows, 96, 0, stream>>>(x, y, dst, ncols, eps); break;
case 128: fused_rms_norm_f32<128><<<nrows, 128, 0, stream>>>(x, y, dst, ncols, eps); break;
case 160: fused_rms_norm_f32<160><<<nrows, 160, 0, stream>>>(x, y, dst, ncols, eps); break;
case 192: fused_rms_norm_f32<192><<<nrows, 192, 0, stream>>>(x, y, dst, ncols, eps); break;
default : fused_rms_norm_f32<224><<<nrows, 224, 0, stream>>>(x, y, dst, ncols, eps); break;
}
}
else if (ncols < 1024) {
const dim3 block_dims(kBlockSize, 1, 1);
fused_rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
}
}
}
@@ -427,7 +484,7 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
}
}
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst, bool is_norm) {
if (!dst->src[1]) {
ggml_cuda_op_rms_norm(ctx, dst);
return;
@@ -453,11 +510,14 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
if (ggml_is_contiguous(src0)) {
const int64_t nrows = ggml_nrows(src0);
if (src0->type == GGML_TYPE_F32) {
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
} else {
fused_rms_norm_f32_cuda((const half *)src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
fused_rms_norm_f32_cuda((const half *)src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
}
} else {
if (is_norm) {
GGML_ABORT("Non-contiguous norm is not implemented");
}
auto ts0 = ggml_type_size(src0->type);
GGML_ASSERT(src0->nb[0] == ts0);
auto s01 = src0->nb[1] / ts0;

View File

@@ -6,7 +6,7 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst, bool is_norm = false);
void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst);

View File

@@ -54,6 +54,10 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(nhave >=2 && nhave <= nreduce);
if (dst->op_params[3] == 1) {
// The dst tensor is just a container for the sources and the reduce op is turned off
return;
}
auto & info = ggml_cuda_info();
#ifdef GGML_USE_NCCL

View File

@@ -4294,9 +4294,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"REDUCE",
"FAKE_CPY",
"FUSED_NORM",
};
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -4405,9 +4406,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"reduce(x1,x2,...)",
"fake_cpy(x,y)",
"norm(x,y)",
};
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -7406,6 +7408,67 @@ struct ggml_tensor * ggml_fused_rms_norm_inplace(
return ggml_fused_rms_norm_impl(ctx, a, b, eps, true);
}
static struct ggml_tensor * ggml_fused_norm_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float eps,
bool inplace) {
if (!b) {
return ggml_norm_impl(ctx, a, eps, inplace);
}
if (ggml_nrows(b) > 1 || a->ne[0] != b->ne[0]) {
struct ggml_tensor * result = ggml_norm_impl(ctx, a, eps, inplace);
result = ggml_mul_impl(ctx, result, b, inplace);
return result;
}
bool is_node = false;
if (!inplace && (a->grad)) {
is_node = true;
}
struct ggml_tensor * result;
if (inplace) {
GGML_ASSERT(a->type == GGML_TYPE_F32);
result = ggml_view_tensor(ctx, a);
} else {
if (a->type == GGML_TYPE_F32) {
result = ggml_dup_tensor(ctx, a);
} else {
result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
}
}
ggml_set_op_params(result, &eps, sizeof(eps));
result->op = GGML_OP_FUSED_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_tensor * ggml_fused_norm(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float eps) {
return ggml_fused_norm_impl(ctx, a, b, eps, false);
}
struct ggml_tensor * ggml_fused_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float eps) {
return ggml_fused_norm_impl(ctx, a, b, eps, true);
}
// ggml_rms_norm_back
struct ggml_tensor * ggml_rms_norm_back(
@@ -15404,6 +15467,88 @@ static void ggml_compute_forward_norm(
}
}
static void ggml_compute_forward_fused_norm_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
if (!src1) {
ggml_compute_forward_norm_f32(params, dst);
return;
}
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src1->ne[0] == src0->ne[0]);
GGML_ASSERT(ggml_nrows(src1) == 1);
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_UNARY_OP_LOCALS
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps > 0.0f);
const float * c = (const float *)src1->data;
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
ggml_float xi = (ggml_float)x[i00];
sum += xi;
}
const float mean = sum/ne00;
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
ggml_float sum2 = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
float v = x[i00] - mean;
y[i00] = v * c[i00];
sum2 += (ggml_float)(v*v);
}
float variance = sum2/ne00;
const float scale = 1.0f/sqrtf(variance + eps);
ggml_vec_scale_f32(ne00, y, scale);
}
}
}
}
static void ggml_compute_forward_fused_norm(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_fused_norm_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_group_rms_norm
static void ggml_compute_forward_rms_norm_f32(
@@ -22853,6 +22998,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
{
ggml_compute_forward_fused_rms_norm(params, tensor);
} break;
case GGML_OP_FUSED_NORM:
{
ggml_compute_forward_fused_norm(params, tensor);
} break;
case GGML_OP_RMS_NORM_BACK:
{
ggml_compute_forward_rms_norm_back(params, tensor);
@@ -23657,6 +23806,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
}
} break;
case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_FUSED_NORM:
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
@@ -24817,6 +24967,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_FUSED_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_GROUP_NORM:
case GGML_OP_CONCAT: