mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 22:54:10 +00:00
WIP
This commit is contained in:
@@ -2718,13 +2718,34 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
|
||||
dst_row.data = dst_up_contiguous.get();
|
||||
ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get());
|
||||
if (dst->src[4]) {
|
||||
ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[4]->data, (const int32_t *)ids->data,
|
||||
(float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1],
|
||||
dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
dst_row.data = dst_gate_contiguous.get();
|
||||
ggml_cuda_mul_mat_q_id(ctx, src0_2, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get());
|
||||
if (dst->src[5]) {
|
||||
ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[5]->data, (const int32_t *)ids->data,
|
||||
(float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1],
|
||||
dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
|
||||
auto unary_op = (ggml_unary_op)dst->op_params[0];
|
||||
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
|
||||
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||
(float *)dst->data, ggml_nelements(dst), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0],
|
||||
1.702f, 7.0f, stream);
|
||||
} else {
|
||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
|
||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||
(float *)dst->data);
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
|
||||
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
|
||||
ggml_cuda_should_use_mmq(next->src[0]->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
|
||||
|
||||
Reference in New Issue
Block a user