Also for matrix x vector

This commit is contained in:
Iwan Kawrakow
2025-05-22 14:05:35 +03:00
parent 11e674af0e
commit 4446e8ca5c

View File

@@ -2521,6 +2521,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
0, src0_2->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
bool fuse_down = false;
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) &&
@@ -2562,14 +2563,15 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
0, next->src[0]->ne[1], 1, dst_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
return true;
fuse_down = true;
} else {
CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
CUDA_CHECK(cudaGetLastError());
return false;
}
CUDA_CHECK(cudaStreamSynchronize(stream));
return fuse_down;
}
}