Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-01-26 17:20:01 +00:00
Fused norm (#1086)
* Adding fused_norm - same idea as fused_rms_norm

* Avoid computing the attention reduce op for cohere2

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
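For readers coming from ggml_fused_rms_norm: the new GGML_OP_FUSED_NORM folds a LayerNorm (GGML_OP_NORM) and the subsequent multiplication by the norm weight (GGML_OP_MUL) into a single graph node. A minimal sketch of what the fusion replaces when building a graph; the helper names, `x`, `norm_weight`, and `eps` are illustrative and not taken from the commit:

```c
#include "ggml.h"

// Illustrative helper: the unfused pair that GGML_OP_FUSED_NORM replaces.
static struct ggml_tensor * layer_norm_unfused(struct ggml_context * ctx,
        struct ggml_tensor * x, struct ggml_tensor * norm_weight, float eps) {
    struct ggml_tensor * cur = ggml_norm(ctx, x, eps);   // GGML_OP_NORM
    return ggml_mul(ctx, cur, norm_weight);              // GGML_OP_MUL
}

// Illustrative helper: the same computation as a single fused node,
// valid when norm_weight is a single row of size x->ne[0].
static struct ggml_tensor * layer_norm_fused(struct ggml_context * ctx,
        struct ggml_tensor * x, struct ggml_tensor * norm_weight, float eps) {
    return ggml_fused_norm(ctx, x, norm_weight, eps);     // GGML_OP_FUSED_NORM
}
```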
@@ -691,6 +691,7 @@ extern "C" {

        GGML_OP_REDUCE,
        GGML_OP_FAKE_CPY,
        GGML_OP_FUSED_NORM,

        GGML_OP_COUNT,
    };
@@ -1487,6 +1488,18 @@ extern "C" {
            struct ggml_tensor * b,
            float eps);

    GGML_API struct ggml_tensor * ggml_fused_norm(
            struct ggml_context * ctx,
            struct ggml_tensor * a,
            struct ggml_tensor * b,
            float eps);

    GGML_API struct ggml_tensor * ggml_fused_norm_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor * a,
            struct ggml_tensor * b,
            float eps);

    // group normalize along ne0*ne1*n_groups
    // used in stable-diffusion
    GGML_API struct ggml_tensor * ggml_group_norm(
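A hedged usage note on the two declared variants, based on ggml_fused_norm_impl further down in this diff: the out-of-place form accepts F32 input (and F16 on the CUDA path, producing an F32 result), while the in-place form views `a` and therefore asserts that `a` is F32. Short sketch; `ctx`, `a`, and `w` are placeholder names:

```c
// Both variants assume w is a single row of size a->ne[0].
struct ggml_tensor * out    = ggml_fused_norm(ctx, a, w, 1e-5f);         // new result tensor
struct ggml_tensor * out_ip = ggml_fused_norm_inplace(ctx, a, w, 1e-5f); // views a; asserts a is F32
```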
@@ -3208,6 +3208,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                ggml_cuda_op_fused_rms_norm(ctx, dst);
            }
            break;
        case GGML_OP_FUSED_NORM:
            ggml_cuda_op_fused_rms_norm(ctx, dst, true);
            break;
        case GGML_OP_MUL_MAT:
            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
@@ -4166,6 +4169,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
        case GGML_OP_ROPE_FAST:
        case GGML_OP_ROPE_CACHE:
            return true;
        case GGML_OP_FUSED_NORM:
            return ggml_is_contiguous(op->src[0]);
        //case GGML_OP_ROPE:
        //    return ggml_is_contiguous(op->src[0]);
        case GGML_OP_IM2COL:
@@ -36,6 +36,42 @@ static __global__ void norm_f32(const T * x, float * dst, const int ncols, const
    }
}

template <int block_size, typename T>
static __global__ void fused_norm_f32(const T * x, const float * c, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    float2 mean_var = make_float2(0.f, 0.f);

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = (float)x[row*ncols + col];
        mean_var.x += xi;
        mean_var.y += xi * xi;
    }

    // sum up partial sums
    mean_var = warp_reduce_sum(mean_var);
    if (block_size > WARP_SIZE) {
        __shared__ float2 s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = mean_var;
        }
        __syncthreads();
        mean_var = s_sum[lane_id];
        mean_var = warp_reduce_sum(mean_var);
    }

    const float mean = mean_var.x / ncols;
    const float var = mean_var.y / ncols - mean * mean;
    const float inv_std = rsqrtf(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row*ncols + col] = (T)(((float)x[row*ncols + col] - mean) * inv_std * c[col]);
    }
}

template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
    // blockIdx.x: num_groups idx
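Each block handles one row: it accumulates the sum and the sum of squares in a single pass, reduces them across the block, and writes (x - mean) / sqrt(var + eps) * c. For reference, a scalar C sketch of the same per-row computation (this helper is not part of the commit):

```c
#include <math.h>

// Per-row reference for fused_norm_f32: dst[j] = (x[j] - mean) / sqrt(var + eps) * c[j].
static void fused_norm_row_ref(const float * x, const float * c, float * dst, int ncols, float eps) {
    float sum = 0.0f, sum2 = 0.0f;
    for (int j = 0; j < ncols; ++j) {
        sum  += x[j];
        sum2 += x[j] * x[j];
    }
    const float mean    = sum / ncols;
    const float var     = sum2 / ncols - mean * mean;   // E[x^2] - E[x]^2, as in the kernel
    const float inv_std = 1.0f / sqrtf(var + eps);
    for (int j = 0; j < ncols; ++j) {
        dst[j] = (x[j] - mean) * inv_std * c[j];
    }
}
```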
@@ -310,26 +346,47 @@ static void rms_norm_f32_nc_cuda(

template <typename src_t>
static void fused_rms_norm_f32_cuda(const src_t * x, const float * y, float * dst,
        const int ncols, const int nrows, const float eps, cudaStream_t stream) {
        const int ncols, const int nrows, const float eps, bool is_norm, cudaStream_t stream) {
    constexpr int kBlockSize = 256;
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < kBlockSize) {
        switch (ncols) {
            case 32: fused_rms_norm_f32< 32><<<nrows, 32, 0, stream>>>(x, y, dst, ncols, eps); break;
            case 64: fused_rms_norm_f32< 64><<<nrows, 64, 0, stream>>>(x, y, dst, ncols, eps); break;
            case 96: fused_rms_norm_f32< 96><<<nrows, 96, 0, stream>>>(x, y, dst, ncols, eps); break;
            case 128: fused_rms_norm_f32<128><<<nrows, 128, 0, stream>>>(x, y, dst, ncols, eps); break;
            case 160: fused_rms_norm_f32<160><<<nrows, 160, 0, stream>>>(x, y, dst, ncols, eps); break;
            case 192: fused_rms_norm_f32<192><<<nrows, 192, 0, stream>>>(x, y, dst, ncols, eps); break;
            default : fused_rms_norm_f32<224><<<nrows, 224, 0, stream>>>(x, y, dst, ncols, eps); break;
    if (is_norm) {
        if (ncols < kBlockSize) {
            switch (ncols) {
                case 32: fused_norm_f32< 32><<<nrows, 32, 0, stream>>>(x, y, dst, ncols, eps); break;
                case 64: fused_norm_f32< 64><<<nrows, 64, 0, stream>>>(x, y, dst, ncols, eps); break;
                case 96: fused_norm_f32< 96><<<nrows, 96, 0, stream>>>(x, y, dst, ncols, eps); break;
                case 128: fused_norm_f32<128><<<nrows, 128, 0, stream>>>(x, y, dst, ncols, eps); break;
                case 160: fused_norm_f32<160><<<nrows, 160, 0, stream>>>(x, y, dst, ncols, eps); break;
                case 192: fused_norm_f32<192><<<nrows, 192, 0, stream>>>(x, y, dst, ncols, eps); break;
                default : fused_norm_f32<224><<<nrows, 224, 0, stream>>>(x, y, dst, ncols, eps); break;
            }
        }
        else if (ncols < 1024) {
            const dim3 block_dims(kBlockSize, 1, 1);
            fused_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
        } else {
            const dim3 block_dims(1024, 1, 1);
            fused_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
        }
    }
    else if (ncols < 1024) {
        const dim3 block_dims(kBlockSize, 1, 1);
        fused_rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
        if (ncols < kBlockSize) {
            switch (ncols) {
                case 32: fused_rms_norm_f32< 32><<<nrows, 32, 0, stream>>>(x, y, dst, ncols, eps); break;
                case 64: fused_rms_norm_f32< 64><<<nrows, 64, 0, stream>>>(x, y, dst, ncols, eps); break;
                case 96: fused_rms_norm_f32< 96><<<nrows, 96, 0, stream>>>(x, y, dst, ncols, eps); break;
                case 128: fused_rms_norm_f32<128><<<nrows, 128, 0, stream>>>(x, y, dst, ncols, eps); break;
                case 160: fused_rms_norm_f32<160><<<nrows, 160, 0, stream>>>(x, y, dst, ncols, eps); break;
                case 192: fused_rms_norm_f32<192><<<nrows, 192, 0, stream>>>(x, y, dst, ncols, eps); break;
                default : fused_rms_norm_f32<224><<<nrows, 224, 0, stream>>>(x, y, dst, ncols, eps); break;
            }
        }
        else if (ncols < 1024) {
            const dim3 block_dims(kBlockSize, 1, 1);
            fused_rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
        } else {
            const dim3 block_dims(1024, 1, 1);
            fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
        }
    }
}
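The launcher keeps the existing block-size heuristic and only switches kernels when `is_norm` is set: rows with fewer than 256 columns use one of the small template instantiations (32 to 192, falling back to 224), rows up to 1023 columns use 256 threads, and wider rows use 1024. An illustrative host-side helper that mirrors that selection (the helper itself is not in the commit):

```c
// Illustrative only: mirrors the ncols -> block size routing in fused_rms_norm_f32_cuda.
static int pick_norm_block_size(int ncols) {
    if (ncols < 256) {
        switch (ncols) {
            case 32: case 64: case 96: case 128: case 160: case 192: return ncols;
            default: return 224;   // any other small width uses the 224-thread instantiation
        }
    }
    return ncols < 1024 ? 256 : 1024;
}
```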
@@ -427,7 +484,7 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    }
}

void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst, bool is_norm) {
    if (!dst->src[1]) {
        ggml_cuda_op_rms_norm(ctx, dst);
        return;
@@ -453,11 +510,14 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
    if (ggml_is_contiguous(src0)) {
        const int64_t nrows = ggml_nrows(src0);
        if (src0->type == GGML_TYPE_F32) {
            fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
            fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
        } else {
            fused_rms_norm_f32_cuda((const half *)src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
            fused_rms_norm_f32_cuda((const half *)src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
        }
    } else {
        if (is_norm) {
            GGML_ABORT("Non-contiguous norm is not implemented");
        }
        auto ts0 = ggml_type_size(src0->type);
        GGML_ASSERT(src0->nb[0] == ts0);
        auto s01 = src0->nb[1] / ts0;
@@ -6,7 +6,7 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)

void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst, bool is_norm = false);

void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst);
@@ -54,6 +54,10 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
    GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(nhave >=2 && nhave <= nreduce);
    if (dst->op_params[3] == 1) {
        // The dst tensor is just a container for the sources and the reduce op is turned off
        return;
    }

    auto & info = ggml_cuda_info();
#ifdef GGML_USE_NCCL
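This early return is the hook used by the second half of the commit: a graph builder can keep a GGML_OP_REDUCE node purely as a container for its per-split sources and skip the reduction by setting op_params[3] = 1. build_cohere2() sets the flag on attn_out further down, and llm_build_ffn then adds the per-split sources instead of the reduced tensor. A small illustrative predicate for that convention (not part of the commit):

```c
#include "ggml.h"
#include <stdbool.h>

// Illustrative helper: has this GGML_OP_REDUCE node been turned into a
// sources-only container via op_params[3] == 1?
static bool reduce_is_disabled(const struct ggml_tensor * t) {
    return t->op == GGML_OP_REDUCE && t->op_params[3] == 1;
}
```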
ggml/src/ggml.c (155 changes)
@@ -4294,9 +4294,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {

    "REDUCE",
    "FAKE_CPY",
    "FUSED_NORM",
};

static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "none",
@@ -4405,9 +4406,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {

    "reduce(x1,x2,...)",
    "fake_cpy(x,y)",
    "norm(x,y)",
};

static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -7406,6 +7408,67 @@ struct ggml_tensor * ggml_fused_rms_norm_inplace(
    return ggml_fused_rms_norm_impl(ctx, a, b, eps, true);
}

static struct ggml_tensor * ggml_fused_norm_impl(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        float eps,
        bool inplace) {

    if (!b) {
        return ggml_norm_impl(ctx, a, eps, inplace);
    }

    if (ggml_nrows(b) > 1 || a->ne[0] != b->ne[0]) {
        struct ggml_tensor * result = ggml_norm_impl(ctx, a, eps, inplace);
        result = ggml_mul_impl(ctx, result, b, inplace);
        return result;
    }

    bool is_node = false;

    if (!inplace && (a->grad)) {
        is_node = true;
    }

    struct ggml_tensor * result;
    if (inplace) {
        GGML_ASSERT(a->type == GGML_TYPE_F32);
        result = ggml_view_tensor(ctx, a);
    } else {
        if (a->type == GGML_TYPE_F32) {
            result = ggml_dup_tensor(ctx, a);
        } else {
            result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
        }
    }

    ggml_set_op_params(result, &eps, sizeof(eps));

    result->op = GGML_OP_FUSED_NORM;
    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
    result->src[0] = a;
    result->src[1] = b;

    return result;
}

struct ggml_tensor * ggml_fused_norm(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        float eps) {
    return ggml_fused_norm_impl(ctx, a, b, eps, false);
}

struct ggml_tensor * ggml_fused_norm_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        float eps) {
    return ggml_fused_norm_impl(ctx, a, b, eps, true);
}

// ggml_rms_norm_back

struct ggml_tensor * ggml_rms_norm_back(
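Note the fallbacks in ggml_fused_norm_impl above: with a NULL weight it degenerates to plain ggml_norm, and when the weight is not a single row matching a->ne[0] it silently builds the unfused norm-then-mul pair, so callers can use ggml_fused_norm unconditionally. The caller-side choice this saves, as an illustrative sketch (the helper and the names x, w are placeholders):

```c
#include "ggml.h"

// Illustrative: the decision ggml_fused_norm_impl already makes internally.
static struct ggml_tensor * build_layer_norm(struct ggml_context * ctx,
        struct ggml_tensor * x, struct ggml_tensor * w, float eps) {
    if (w && ggml_nrows(w) == 1 && w->ne[0] == x->ne[0]) {
        return ggml_fused_norm(ctx, x, w, eps);          // single fused node
    }
    struct ggml_tensor * cur = ggml_norm(ctx, x, eps);   // unfused fallback
    return w ? ggml_mul(ctx, cur, w) : cur;
}
```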
@@ -15404,6 +15467,88 @@ static void ggml_compute_forward_norm(
    }
}

static void ggml_compute_forward_fused_norm_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    if (!src1) {
        ggml_compute_forward_norm_f32(params, dst);
        return;
    }

    GGML_ASSERT(ggml_are_same_shape(src0, dst));

    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src1->ne[0] == src0->ne[0]);
    GGML_ASSERT(ggml_nrows(src1) == 1);

    const int ith = params->ith;
    const int nth = params->nth;

    GGML_TENSOR_UNARY_OP_LOCALS

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    GGML_ASSERT(eps > 0.0f);

    const float * c = (const float *)src1->data;

    // TODO: optimize
    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

                ggml_float sum = 0.0;
                for (int64_t i00 = 0; i00 < ne00; i00++) {
                    ggml_float xi = (ggml_float)x[i00];
                    sum += xi;
                }

                const float mean = sum/ne00;

                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);

                ggml_float sum2 = 0.0;
                for (int64_t i00 = 0; i00 < ne00; i00++) {
                    float v = x[i00] - mean;
                    y[i00] = v * c[i00];
                    sum2 += (ggml_float)(v*v);
                }

                float variance = sum2/ne00;
                const float scale = 1.0f/sqrtf(variance + eps);

                ggml_vec_scale_f32(ne00, y, scale);

            }
        }
    }
}

static void ggml_compute_forward_fused_norm(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_fused_norm_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

// ggml_compute_forward_group_rms_norm

static void ggml_compute_forward_rms_norm_f32(
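Since the CPU path computes exactly a LayerNorm followed by a per-channel multiply, the fused op can be sanity-checked against the unfused graph on the CPU backend. A rough, self-contained test sketch using the public ggml API; it assumes a CPU-only build of this fork and that graph helpers such as ggml_graph_compute_with_ctx are available (sizes and values are arbitrary):

```c
#include "ggml.h"
#include <math.h>
#include <stdio.h>

int main(void) {
    struct ggml_init_params ip = { 16*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(ip);

    const int n = 256, rows = 4;
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n, rows);
    struct ggml_tensor * w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
    for (int i = 0; i < n*rows; ++i) ((float *)x->data)[i] = 0.01f*i - 1.0f;
    for (int i = 0; i < n;      ++i) ((float *)w->data)[i] = 1.0f + 0.001f*i;

    const float eps = 1e-5f;
    struct ggml_tensor * ref   = ggml_mul(ctx, ggml_norm(ctx, x, eps), w);  // unfused pair
    struct ggml_tensor * fused = ggml_fused_norm(ctx, x, w, eps);           // fused op

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, ref);
    ggml_build_forward_expand(gf, fused);
    ggml_graph_compute_with_ctx(ctx, gf, 4);

    float max_diff = 0.0f;
    for (int i = 0; i < n*rows; ++i) {
        float d = fabsf(((float *)ref->data)[i] - ((float *)fused->data)[i]);
        if (d > max_diff) max_diff = d;
    }
    printf("max |fused - norm*mul| = %g\n", max_diff);
    ggml_free(ctx);
    return 0;
}
```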
@@ -22853,6 +22998,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
            {
                ggml_compute_forward_fused_rms_norm(params, tensor);
            } break;
        case GGML_OP_FUSED_NORM:
            {
                ggml_compute_forward_fused_norm(params, tensor);
            } break;
        case GGML_OP_RMS_NORM_BACK:
            {
                ggml_compute_forward_rms_norm_back(params, tensor);
@@ -23657,6 +23806,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                }
            } break;
        case GGML_OP_FUSED_RMS_NORM:
        case GGML_OP_FUSED_NORM:
            {
                GGML_ABORT("fatal error"); // TODO: not implemented
            }
@@ -24817,6 +24967,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_NORM:
        case GGML_OP_RMS_NORM:
        case GGML_OP_FUSED_RMS_NORM:
        case GGML_OP_FUSED_NORM:
        case GGML_OP_RMS_NORM_BACK:
        case GGML_OP_GROUP_NORM:
        case GGML_OP_CONCAT:
@@ -678,9 +678,10 @@ ggml_tensor * llm_build_context::llm_build_ffn(
        auto norm = (ggml_split_tensor_t *)ffn_norm->extra;
        GGML_ASSERT(norm->splits[id]);
        if (is_norm) {
            cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM, cb, il);
            GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM);
            cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
            //cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM, cb, il);
            //GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM);
            //cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
            cur = ggml_fused_norm(ctx, cur, norm->splits[id], lctx.model.hparams.f_norm_eps);
        } else {
            cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
        }
@@ -700,6 +701,13 @@ ggml_tensor * llm_build_context::llm_build_ffn(
        if (cur->ne[1] > 32 && lctx.cparams.split_mode_f16) {
            cur = ggml_cast(ctx, cur, GGML_TYPE_F16);
        }
        if (add_extra && add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1) {
            // When the reduce op is turned off via op_params[3] == 1, we need to add each src
            // rather than add the reduced add_extra result to the reduced ffn result.
            GGML_ASSERT(add_extra->src[id]); // TODO: fix this! It can be null if the splits of the attention and ffn tensors are different
            cur = ggml_add(ctx, cur, add_extra->src[id]);
            cb(cur, "ffn_with_extra", il_cb);
        }
        if (graph) {
            ggml_build_forward_expand(graph, cur);
        }
@@ -711,7 +719,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
        ffn[id_last] = ggml_add(ctx, ffn[id_last], input);
        cb(ffn[id_last], "ffn_with_inp", il);
    }
    if (add_extra) {
    if (add_extra && !(add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1)) {
        ffn[id_last] = ggml_add(ctx, ffn[id_last], add_extra);
        cb(ffn[id_last], "ffn_with_inp", il);
    }
@@ -7287,6 +7295,8 @@ ggml_cgraph * llm_build_context::build_cohere2() {
            inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
        }

        attn_out->op_params[3] = 1; // i.e., turn off the reduce operation as it is not required

        // feed-forward network
        cur = llm_build_ffn(ctx0, lctx, model.layers[il].attn_norm, inpL, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
                NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
@@ -9379,9 +9389,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
    auto cur = get_input_tensor_sm_graph(input, id);
    if (attn_norm) {
        if (is_norm) {
            cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM, cb, il);
            GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM);
            cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
            //cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM, cb, il);
            //GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM);
            //cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
            cur = ggml_fused_norm(ctx0, cur, attn_norm->splits[id], lctx.model.hparams.f_norm_eps);
        } else {
            cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
        }