This commit is contained in:
Iwan Kawrakow
2025-12-13 14:52:23 +00:00
parent 81fc5e3f08
commit d11def5ceb
2 changed files with 75 additions and 14 deletions

View File

@@ -321,6 +321,15 @@ static __global__ void k_fast_add(int64_t ne0, int64_t nelem, const float * x, c
z[i] = x[i] + y[i % ne0];
}
template <typename src1_t, typename src2_t, typename dst_t>
static __global__ void k_fast_add_2(int64_t ne0, int64_t nelem, const src1_t * x, const src2_t * y, dst_t * z) {
int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= nelem) {
return;
}
z[i] = (dst_t)((float)x[i] + (float)y[i]);
}
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (ggml_nrows(dst->src[1]) == 1 && dst->src[0]->ne[0] == dst->src[1]->ne[0] &&
dst->type == GGML_TYPE_F32 && dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 &&
@@ -332,6 +341,45 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
(const float *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data);
return;
}
if (ggml_is_contiguous(dst->src[0]) && ggml_are_same_shape(dst->src[0], dst->src[1]) && ggml_is_contiguous(dst)) {
constexpr int kBlockSize = 256;
auto nelem = ggml_nelements(dst);
int nblocks = (nelem + kBlockSize - 1)/kBlockSize;
if (dst->type == GGML_TYPE_F16) {
if (dst->src[0]->type == GGML_TYPE_F16 && dst->src[1]->type == GGML_TYPE_F16) {
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const half *)dst->src[0]->data, (const half *)dst->src[1]->data, (half *)dst->data);
}
else if (dst->src[0]->type == GGML_TYPE_F16 && dst->src[1]->type == GGML_TYPE_F32) {
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const half *)dst->src[0]->data, (const float *)dst->src[1]->data, (half *)dst->data);
}
else if (dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32) {
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const float *)dst->src[0]->data, (const float *)dst->src[1]->data, (half *)dst->data);
} else {
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const float *)dst->src[0]->data, (const half *)dst->src[1]->data, (half *)dst->data);
}
} else {
if (dst->src[0]->type == GGML_TYPE_F16 && dst->src[1]->type == GGML_TYPE_F16) {
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const half *)dst->src[0]->data, (const half *)dst->src[1]->data, (float *)dst->data);
}
else if (dst->src[0]->type == GGML_TYPE_F16 && dst->src[1]->type == GGML_TYPE_F32) {
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const half *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data);
}
else if (dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32) {
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const float *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data);
} else {
k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const float *)dst->src[0]->data, (const half *)dst->src[1]->data, (float *)dst->data);
}
}
return;
}
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

View File

@@ -693,9 +693,9 @@ ggml_tensor * llm_build_context::llm_build_ffn(
if (ffn.size() > 2) {
cur->op_params[0] = 0xff;
}
if (cur->type != GGML_TYPE_F32) {
cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
}
//if (cur->type != GGML_TYPE_F32) {
// cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
//}
return cur;
}
@@ -7246,24 +7246,21 @@ ggml_cgraph * llm_build_context::build_cohere2() {
ggml_backend_sched_set_tensor_backend(lctx.sched, cur->src[0], ggml_backend_sched_get_backend(lctx.sched, id));
}
cb(cur, "attn_norm", il);
struct ggml_tensor * ffn_inp = cur;
auto ffn_inp = cur;
// self-attention
cur = build_std_attention(gf, nullptr, cur, inp_pos, nullptr, KQ_mask_l, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), 0.f,
auto attn_out = build_std_attention(gf, nullptr, cur, inp_pos, nullptr, KQ_mask_l, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), 0.f,
is_sliding ? hparams.n_swa : 0, il, is_sliding, true);
cur = ggml_add(ctx0, cur, inpL);
cb(cur, "attn_out", il);
cb(attn_out, "attn_out", il);
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
attn_out = ggml_get_rows(ctx0, attn_out, inp_out_ids);
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
auto attn_out = cur;
// feed-forward network
cur = llm_build_ffn(ctx0, lctx, nullptr, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
@@ -7272,6 +7269,7 @@ ggml_cgraph * llm_build_context::build_cohere2() {
// add together residual + FFN + self-attention
cur = ggml_add(ctx0, cur, attn_out);
cur = ggml_add(ctx0, cur, inpL);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
@@ -7280,9 +7278,9 @@ ggml_cgraph * llm_build_context::build_cohere2() {
}
cur = inpL;
if (cur->type != GGML_TYPE_F32) {
cur = ggml_cast(ctx0, cur, GGML_TYPE_F32);
}
//if (cur->type != GGML_TYPE_F32) {
// cur = ggml_cast(ctx0, cur, GGML_TYPE_F32);
//}
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
cb(cur, "result_norm", -1);
@@ -9491,6 +9489,21 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
}
GGML_ASSERT(!attn.empty());
if (attn.size() == 1) return attn.front();
//if (attn.size() > 2 && attn.size()%2 == 0) {
// for (int id = 0; id < int(attn.size()/2); ++id) {
// attn[id] = ggml_add(ctx0, attn[2*id+0], attn[2*id+1]);
// attn[id]->op_params[0] = 0xff;
// }
// attn.resize(attn.size()/2);
// auto cur = ggml_add(ctx0, attn[0], attn[1]);
// cur->op_params[0] = 0xff;
// cur->op_params[0] = 0xff;
// for (int id = 2; id < (int)attn.size(); ++id) {
// cur = ggml_add(ctx0, cur, attn[id]);
// cb(cur, "combine_attn", il);
// }
// return cur;
//}
auto cur = ggml_add(ctx0, attn[0], attn[1]);
cb(cur, "combine_attn", il);
cur->op_params[0] = 0xff;