mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
Also for matrix x vector
This commit is contained in:
@@ -2521,6 +2521,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
0, src0_2->ne[1], 1, src1_padded_col_size, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
bool fuse_down = false;
|
||||
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
|
||||
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
|
||||
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) &&
|
||||
@@ -2562,14 +2563,15 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
0, next->src[0]->ne[1], 1, dst_padded_col_size, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return true;
|
||||
fuse_down = true;
|
||||
} else {
|
||||
CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
|
||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
|
||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
return false;
|
||||
}
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
return fuse_down;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user