Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-03 13:04:59 +00:00)
Adding fused rms_norm (#42)
* Fused rms_norm: works on the CPU
* Fused rms_norm WIP
* Fused rms_norm WIP
* Fused rms_norm WIP
* Fused rms_norm WIP
* Fused rms_norm WIP

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
@@ -480,6 +480,7 @@ extern "C" {
        GGML_OP_RMS_NORM,
        GGML_OP_RMS_NORM_BACK,
        GGML_OP_GROUP_NORM,
        GGML_OP_FUSED_RMS_NORM,

        GGML_OP_MUL_MAT,
        GGML_OP_MUL_MAT_ID,
@@ -1159,6 +1160,18 @@ extern "C" {
            struct ggml_tensor  * a,
            float                 eps);

    GGML_API struct ggml_tensor * ggml_fused_rms_norm(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            float                 eps);

    GGML_API struct ggml_tensor * ggml_fused_rms_norm_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            float                 eps);

    // group normalize along ne0*ne1*n_groups
    // used in stable-diffusion
    GGML_API struct ggml_tensor * ggml_group_norm(
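For orientation, a minimal usage sketch (not part of the commit; the tensor names and the eps value here are made up): ggml_fused_rms_norm takes an activation tensor and a single-row weight tensor and replaces the usual rms_norm-then-mul pair with one graph node.

    // hypothetical graph-building code; `cur` is an F32 activation tensor and
    // `norm_w` a single-row F32 weight with norm_w->ne[0] == cur->ne[0]

    // unfused form: two nodes in the graph
    struct ggml_tensor * t = ggml_rms_norm(ctx, cur, 1e-5f);
    t = ggml_mul(ctx, t, norm_w);

    // fused form introduced by this commit: one node
    struct ggml_tensor * t_fused = ggml_fused_rms_norm(ctx, cur, norm_w, 1e-5f);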
@@ -2248,6 +2248,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
        case GGML_OP_RMS_NORM:
            ggml_cuda_op_rms_norm(ctx, dst);
            break;
        case GGML_OP_FUSED_RMS_NORM:
            ggml_cuda_op_fused_rms_norm(ctx, dst);
            break;
        case GGML_OP_MUL_MAT:
            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
@@ -2871,6 +2874,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_RMS_NORM:
        case GGML_OP_FUSED_RMS_NORM:
        case GGML_OP_SCALE:
        case GGML_OP_SOFTCAP:
        case GGML_OP_SQR:
@@ -131,6 +131,40 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
    }
}

template <int block_size>
static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row*ncols + col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float mean = tmp / ncols;
    const float scale = rsqrtf(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
    }
}

static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
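As a plain-C reference for what one block of this kernel computes per row (a sketch for clarity only; the helper name is made up and nothing here is part of the commit):

    #include <math.h>

    // dst[i] = y[i] * x[i] / sqrt(mean(x*x) + eps) for one row of length ncols
    static void fused_rms_norm_row_ref(const float * x, const float * y,
                                       float * dst, int ncols, float eps) {
        float sum = 0.0f;
        for (int i = 0; i < ncols; ++i) {
            sum += x[i]*x[i];
        }
        const float scale = 1.0f/sqrtf(sum/ncols + eps);
        for (int i = 0; i < ncols; ++i) {
            dst[i] = scale * y[i] * x[i];
        }
    }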
@@ -163,6 +197,18 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
    }
}

static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
        const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
    }
}

void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
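The launch strategy mirrors the existing rms_norm_f32_cuda path: one block per row, a single warp of WARP_SIZE threads for rows narrower than 1024 elements, and a 1024-thread block (which takes the shared-memory reduction branch in the kernel) otherwise. As a purely illustrative example, ncols = 4096 with nrows = 32 would launch 32 blocks of 1024 threads, each thread accumulating four squared elements before the reduction.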
@@ -222,3 +268,32 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}

void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    if (!dst->src[1]) {
        ggml_cuda_op_rms_norm(ctx, dst);
        return;
    }
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->ne[0] == src1->ne[0]);
    GGML_ASSERT(ggml_nrows(src1) == 1);

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
}
@@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -104,6 +104,7 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
    GGML_METAL_KERNEL_TYPE_RMS_NORM,
    GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM,
    GGML_METAL_KERNEL_TYPE_GROUP_NORM,
    GGML_METAL_KERNEL_TYPE_NORM,
    GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
@@ -613,6 +614,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K, get_rows_iq6_k, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM, fused_rms_norm, ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
@@ -884,6 +886,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
            return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op);
        case GGML_OP_SOFT_MAX:
        case GGML_OP_RMS_NORM:
        case GGML_OP_FUSED_RMS_NORM:
        case GGML_OP_GROUP_NORM:
            return ctx->support_simdgroup_reduction;
        case GGML_OP_NORM:
@@ -2606,6 +2609,38 @@ static enum ggml_status ggml_metal_graph_compute(

                    const int64_t nrows = ggml_nrows(src0);

                    [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                } break;
            case GGML_OP_FUSED_RMS_NORM:
                {
                    GGML_ASSERT(ne00 % 4 == 0);
                    GGML_ASSERT(ggml_is_contiguous_1(src0));
                    GGML_ASSERT(src1->ne[0] == src0->ne[0]);
                    GGML_ASSERT(src1->type == GGML_TYPE_F32);
                    GGML_ASSERT(ggml_nrows(src1) == 1);

                    float eps;
                    memcpy(&eps, dst->op_params, sizeof(float));

                    int nth = 32; // SIMD width

                    while (nth < ne00/4 && nth < 1024) {
                        nth *= 2;
                    }

                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM].pipeline;

                    [encoder setComputePipelineState:pipeline];
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
                    [encoder setBytes:&eps length:sizeof( float) atIndex:5];
                    [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

                    const int64_t nrows = ggml_nrows(src0);

                    [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                } break;
            case GGML_OP_GROUP_NORM:
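The threadgroup sizing follows the existing RMS_NORM path: nth starts at the SIMD width of 32 and doubles while it stays below both ne00/4 and 1024, and one threadgroup is dispatched per row with 32 floats of threadgroup memory for the cross-simdgroup reduction. For an illustrative ne00 of 4096, nth grows 32, 64, ..., 1024, so each thread handles one float4 of the row.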
@@ -1038,6 +1038,57 @@ kernel void kernel_rms_norm(
    }
}

kernel void kernel_fused_rms_norm(
        device const void * src0,
        device const void * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant     float & eps,
        threadgroup float * buf [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);

    float4 sumf = 0;
    float all_sum = 0;

    // parallel sum
    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
        sumf += x[i00] * x[i00];
    }
    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
    all_sum = simd_sum(all_sum);
    if (ntg > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = all_sum;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        all_sum = buf[tiisg];
        all_sum = simd_sum(all_sum);
    }

    const float mean = all_sum/ne00;
    const float scale = 1.0f/sqrt(mean + eps);

    device float4 * y = (device float4 *) (dst + tgpig*ne00);
    device float4 * z = (device float4 *)src1;
    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
        y[i00] = x[i00] * z[i00] * scale;
    }
}

kernel void kernel_group_norm(
        device const float * src0,
        device       float * dst,
ggml/src/ggml.c
@@ -3144,6 +3144,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "RMS_NORM",
    "RMS_NORM_BACK",
    "GROUP_NORM",
    "FUSED_RMS_NORM",

    "MUL_MAT",
    "MUL_MAT_ID",
@@ -3207,7 +3208,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "CROSS_ENTROPY_LOSS_BACK",
};

-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "none",
@@ -3234,6 +3235,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "rms_norm(x)",
    "rms_norm_back(x)",
    "group_norm(x)",
    "fused_rms_norm(x)",

    "X*Y",
    "X[i]*Y",
@@ -3297,7 +3299,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "cross_entropy_loss_back(x,y)",
};

-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -5737,6 +5739,57 @@ struct ggml_tensor * ggml_rms_norm_inplace(
    return ggml_rms_norm_impl(ctx, a, eps, true);
}

static struct ggml_tensor * ggml_fused_rms_norm_impl(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        float eps,
        bool inplace) {

    if (!b) {
        return ggml_rms_norm_impl(ctx, a, eps, inplace);
    }

    if (ggml_nrows(b) > 1 || a->ne[0] != b->ne[0]) {
        struct ggml_tensor * result = ggml_rms_norm_impl(ctx, a, eps, inplace);
        result = ggml_mul_impl(ctx, result, b, inplace);
        return result;
    }

    bool is_node = false;

    if (!inplace && (a->grad)) {
        is_node = true;
    }

    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

    ggml_set_op_params(result, &eps, sizeof(eps));

    result->op     = GGML_OP_FUSED_RMS_NORM;
    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
    result->src[0] = a;
    result->src[1] = b;

    return result;
}

struct ggml_tensor * ggml_fused_rms_norm(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        float eps) {
    return ggml_fused_rms_norm_impl(ctx, a, b, eps, false);
}

struct ggml_tensor * ggml_fused_rms_norm_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        float eps) {
    return ggml_fused_rms_norm_impl(ctx, a, b, eps, true);
}

// ggml_rms_norm_back

struct ggml_tensor * ggml_rms_norm_back(
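Note the two guards at the top of ggml_fused_rms_norm_impl: a NULL b degenerates to a plain rms_norm, and a b that is not a single row of matching width falls back to the unfused rms_norm-plus-mul pair, so callers do not have to check whether fusion applies before using the fused op.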
@@ -12455,6 +12508,78 @@ static void ggml_compute_forward_rms_norm(
    }
}

static void ggml_compute_forward_fused_rms_norm_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    if (!src1) {
        ggml_compute_forward_rms_norm_f32(params, dst);
        return;
    }

    GGML_ASSERT(ggml_are_same_shape(src0, dst));

    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src1->ne[0] == src0->ne[0]);
    GGML_ASSERT(ggml_nrows(src1) == 1);

    const int ith = params->ith;
    const int nth = params->nth;

    GGML_TENSOR_UNARY_OP_LOCALS

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    GGML_ASSERT(eps > 0.0f);

    // TODO: optimize
    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

                ggml_float sum = 0.0;
                for (int64_t i00 = 0; i00 < ne00; i00++) {
                    sum += (ggml_float)(x[i00] * x[i00]);
                }

                const float mean = sum/ne00;

                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);

                const float scale = 1.0f/sqrtf(mean + eps);

                ggml_vec_mul_f32(ne00, y, x, (const float *)src1->data);
                ggml_vec_scale_f32(ne00, y, scale);

            }
        }
    }
}

static void ggml_compute_forward_fused_rms_norm(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_fused_rms_norm_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

static void ggml_compute_forward_rms_norm_back_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {
@@ -17708,6 +17833,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
            {
                ggml_compute_forward_rms_norm(params, tensor);
            } break;
        case GGML_OP_FUSED_RMS_NORM:
            {
                ggml_compute_forward_fused_rms_norm(params, tensor);
            } break;
        case GGML_OP_RMS_NORM_BACK:
            {
                ggml_compute_forward_rms_norm_back(params, tensor);
@@ -18398,6 +18527,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                        zero_table);
            }
        } break;
        case GGML_OP_FUSED_RMS_NORM:
            {
                GGML_ABORT("fatal error"); // TODO: not implemented
            }
        case GGML_OP_RMS_NORM_BACK:
            {
                GGML_ABORT("fatal error"); // TODO: not implemented
@@ -19465,6 +19598,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_DIV:
        case GGML_OP_NORM:
        case GGML_OP_RMS_NORM:
        case GGML_OP_FUSED_RMS_NORM:
        case GGML_OP_RMS_NORM_BACK:
        case GGML_OP_GROUP_NORM:
        case GGML_OP_CONCAT:
@@ -7987,6 +7987,16 @@ static struct ggml_tensor * llm_build_norm(
        llm_norm_type type,
        const llm_build_cb & cb,
        int il, float scale_eps = 1) {

    if (type == LLM_NORM_RMS && mw) {
        cur = ggml_fused_rms_norm(ctx, cur, mw, scale_eps * hparams.f_norm_rms_eps);
        if (mb) {
            cb(cur, "fused_norm", il);
            cur = ggml_add(ctx, cur, mb);
        }
        return cur;
    }

    switch (type) {
        case LLM_NORM:     cur = ggml_norm    (ctx, cur, hparams.f_norm_eps); break;
        case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, scale_eps * hparams.f_norm_rms_eps); break;
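In effect, for RMS-norm layers that carry a weight tensor mw, llm_build_norm now emits a single fused node (plus an optional bias add) and returns early, instead of the norm followed by an elementwise multiply. A rough sketch of the pattern being replaced, assuming the usual llama.cpp llm_build_norm structure (the unfused path is not shown in this hunk):

    // previous RMS path, approximately:
    cur = ggml_rms_norm(ctx, cur, scale_eps * hparams.f_norm_rms_eps);
    cur = ggml_mul(ctx, cur, mw);            // per-channel weight
    if (mb) {
        cur = ggml_add(ctx, cur, mb);        // optional bias
    }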