Even more fused ops (#868)

* Fuse Q, K, V gemv+add (sketched below)

* More gemv+add fusing

* Faster copy when tensors are contiguous

Relevant for storing data into the KV cache. I see a ~1% speedup
for fast models (Ling-mini-2.0, gpt-oss-20b, etc.).

* Cleanup

* Make sure the bias really is 1 row to use fusion
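
For reference, a minimal CPU sketch of what "gemv+add" fusion means here. It is not part of this commit and uses plain float instead of the quantized types the CUDA kernels operate on; the function names are illustrative only. The point is that the bias add is folded into the pass that produces each output element, so the intermediate gemv result is never written out and re-read.

// Unfused: gemv, then a separate pass over y to add the bias (two trips over y).
#include <vector>
#include <cstddef>

static void gemv_then_add(const std::vector<float> & W, const std::vector<float> & x,
                          const std::vector<float> & b, std::vector<float> & y,
                          size_t nrows, size_t ncols) {
    for (size_t r = 0; r < nrows; ++r) {
        float acc = 0.0f;
        for (size_t c = 0; c < ncols; ++c) acc += W[r*ncols + c] * x[c];
        y[r] = acc;
    }
    for (size_t r = 0; r < nrows; ++r) y[r] += b[r];
}

// Fused: the bias is accumulated in the same pass that produces y[r].
static void gemv_add_fused(const std::vector<float> & W, const std::vector<float> & x,
                           const std::vector<float> & b, std::vector<float> & y,
                           size_t nrows, size_t ncols) {
    for (size_t r = 0; r < nrows; ++r) {
        float acc = b[r];
        for (size_t c = 0; c < ncols; ++c) acc += W[r*ncols + c] * x[c];
        y[r] = acc;
    }
}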

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author: Kawrakow
Date: 2025-10-27 16:09:01 +02:00
Committed by: GitHub
Parent: bf991ba60a
Commit: eb8116b097

6 changed files with 159 additions and 15 deletions


@@ -2078,9 +2078,43 @@ static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tenso
src0->type, stream);
CUDA_CHECK(cudaGetLastError());
// The code below handles the case when Q, K, V have a bias applied after the respective matrix multiplication.
// In that case the graph contains mul_mat(Q) -> mul_mat(K) -> mul_mat(V) -> add(Q) -> add(K) -> add(V)
if (cgraph && node_n + 5 < cgraph->n_nodes &&
cgraph->nodes[node_n+1]->op == GGML_OP_MUL_MAT &&
cgraph->nodes[node_n+2]->op == GGML_OP_MUL_MAT &&
ggml_is_quantized(cgraph->nodes[node_n+1]->src[0]->type) &&
ggml_is_quantized(cgraph->nodes[node_n+2]->src[0]->type) &&
cgraph->nodes[node_n+3]->op == GGML_OP_ADD &&
cgraph->nodes[node_n+4]->op == GGML_OP_ADD &&
cgraph->nodes[node_n+5]->op == GGML_OP_ADD &&
cgraph->nodes[node_n+0] == cgraph->nodes[node_n+3]->src[0] &&
cgraph->nodes[node_n+1] == cgraph->nodes[node_n+4]->src[0] &&
cgraph->nodes[node_n+2] == cgraph->nodes[node_n+5]->src[0]) {
for (int i = 0; i < 3; ++i) {
auto src0_i = cgraph->nodes[node_n+i]->src[0];
ggml_cuda_op_mul_mat_vec_q_biased(ctx, src0_i, src1, cgraph->nodes[node_n+i], cgraph->nodes[node_n+i+3]->src[1],
(const char *)src0_i->data, nullptr, src1_quantized.get(), (float *)cgraph->nodes[node_n+i]->data,
0, src0_i->ne[1], src1->ne[1], ne10_padded, stream);
CUDA_CHECK(cudaGetLastError());
}
node_n += 5;
} else if (cgraph && node_n + 1 < cgraph->n_nodes &&
cgraph->nodes[node_n+1]->op == GGML_OP_ADD &&
dst == cgraph->nodes[node_n+1]->src[0] &&
dst->ne[0] == cgraph->nodes[node_n+1]->src[1]->ne[0] &&
cgraph->nodes[node_n+1]->src[1]->type == GGML_TYPE_F32 &&
ggml_nrows(cgraph->nodes[node_n+1]->src[1]) == 1) {
// We have a bias applied after the matrix multiplication and we can fuse it
ggml_cuda_op_mul_mat_vec_q_biased(ctx, dst->src[0], src1, cgraph->nodes[node_n+1], cgraph->nodes[node_n+1]->src[1],
(const char *)dst->src[0]->data, nullptr, src1_quantized.get(), (float *)cgraph->nodes[node_n+1]->data,
0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
++node_n;
} else {
ggml_cuda_op_mul_mat_vec_q(ctx, src0, src1, dst, (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
0, src0->ne[1], src1->ne[1], ne10_padded, stream);
CUDA_CHECK(cudaGetLastError());
}
} else {
quantize_mmq_q8_1_cuda((const float *)src1->data, src1_quantized.get(), src1->ne[0], src1->ne[1], 1, ne10_padded, src0->type, stream);
CUDA_CHECK(cudaGetLastError());
@@ -2101,8 +2135,21 @@ static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tenso
if (dst->op != GGML_OP_MUL_MAT || dst->src[1] != src1 || !ggml_is_quantized(dst->src[0]->type)) break;
if (!is_gemv && mmq_get_q8_1_ds_layout(src0->type) != mmq_get_q8_1_ds_layout(dst->src[0]->type)) break;
if (is_gemv) {
if (node_n + 1 < cgraph->n_nodes &&
cgraph->nodes[node_n+1]->op == GGML_OP_ADD &&
dst == cgraph->nodes[node_n+1]->src[0] &&
dst->ne[0] == cgraph->nodes[node_n+1]->src[1]->ne[0] &&
cgraph->nodes[node_n+1]->src[1]->type == GGML_TYPE_F32 &&
ggml_nrows(cgraph->nodes[node_n+1]->src[1]) == 1) {
// We have a bias applied after the matrix multiplication and we can fuse it
ggml_cuda_op_mul_mat_vec_q_biased(ctx, dst->src[0], src1, cgraph->nodes[node_n+1], cgraph->nodes[node_n+1]->src[1],
(const char *)dst->src[0]->data, nullptr, src1_quantized.get(), (float *)cgraph->nodes[node_n+1]->data,
0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
++node_n;
} else {
ggml_cuda_op_mul_mat_vec_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
(float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
}
} else {
ggml_cuda_op_mul_mat_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
(float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
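
The same bias-fusion check is written out twice in this file. A hypothetical standalone predicate (not in the commit) restates the single mul_mat->add conditions in isolation: the next graph node must be an ADD that consumes this mul_mat result, and its second operand must be an F32 bias of matching width with exactly one row. It works directly on the node array so it only needs ggml.h.

#include "ggml.h"

static bool can_fuse_bias_add(ggml_tensor * const * nodes, int n_nodes, int node_n) {
    if (node_n + 1 >= n_nodes) return false;
    const ggml_tensor * dst = nodes[node_n];      // the MUL_MAT result
    const ggml_tensor * add = nodes[node_n + 1];  // candidate ADD node
    return add->op == GGML_OP_ADD &&
           add->src[0] == dst &&                  // the ADD consumes the mul_mat output
           add->src[1]->type == GGML_TYPE_F32 &&  // bias must be plain F32
           add->src[1]->ne[0] == dst->ne[0] &&    // same row width as the output
           ggml_nrows(add->src[1]) == 1;          // and a single row, i.e. a true bias
}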


@@ -313,7 +313,25 @@ void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(&aux_dst, &aux_src, &aux_dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
}
static __global__ void k_fast_add(int64_t ne0, int64_t nelem, const float * x, const float * y, float * z) {
int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= nelem) {
return;
}
z[i] = x[i] + y[i % ne0];
}
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (ggml_nrows(dst->src[1]) == 1 && dst->src[0]->ne[0] == dst->src[1]->ne[0] &&
dst->type == GGML_TYPE_F32 && dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 &&
ggml_are_same_shape(dst, dst->src[0]) && ggml_is_contiguous(dst)) {
constexpr int kBlockSize = 256;
auto nelem = ggml_nelements(dst);
int nblocks = (nelem + kBlockSize - 1)/kBlockSize;
k_fast_add<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const float *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data);
return;
}
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
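
The specialized path above handles the common "add one bias row to every row" case with a flat kernel (ceil(nelem/256) blocks of 256 threads, one element per thread) instead of the generic broadcast machinery. A CPU reference of the intended result, not from this commit and assuming the contiguous F32 tensors the branch checks for:

static void fast_add_ref(int64_t ne0, int64_t nelem, const float * x, const float * y, float * z) {
    // y holds ne0 bias values and is broadcast across every row of x:
    // element i of the output picks bias element i % ne0.
    for (int64_t i = 0; i < nelem; ++i) {
        z[i] = x[i] + y[i % ne0];
    }
}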


@@ -38,6 +38,25 @@ static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne
cpy_1(cx + x_offset, cdst + dst_offset);
}
template <typename src_t, typename dst_t>
static __global__ void cpy_flt_contiguous(const char * cx, char * cdst_direct, const int ne,
char ** cdst_indirect, int graph_cpynode_index) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
auto dst = (cdst_indirect != nullptr) ? (dst_t *)cdst_indirect[graph_cpynode_index] : (dst_t *)cdst_direct;
auto src = (const src_t *)cx;
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
dst[i] = __float2bfloat16(src[i]);
} else {
dst[i] = (dst_t)src[i];
}
}
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
float * cdstf = (float *)(cdsti);
@@ -163,6 +182,16 @@ static void ggml_cpy_flt_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}
template<typename src_t, typename dst_t>
static void ggml_cpy_flt_contiguous_cuda(
const char * cx, char * cdst, const int ne,
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, cdst_indirect, graph_cpynode_index++);
}
static void ggml_cpy_f32_q8_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -404,6 +433,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char * src0_ddc = (char *) src0->data;
char * src1_ddc = (char *) src1->data;
bool fast_cpy = ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_are_same_shape(src0, src1);
char ** dest_ptrs_d = nullptr;
int graph_cpynode_index = -1;
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
@@ -429,11 +460,23 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
}
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
if (fast_cpy) {
ggml_cpy_flt_contiguous_cuda<float, float>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
if (fast_cpy) {
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
if (fast_cpy) {
ggml_cpy_flt_contiguous_cuda<float, half>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -505,6 +548,7 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
}
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
bool fast_cpy = ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_are_same_shape(src0, src1);
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
// Prioritize CUDA graph compatibility over direct memory copy optimization.
// Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
@@ -514,11 +558,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return nullptr;
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
return fast_cpy ? (void *)cpy_flt_contiguous<float, float> : (void*) cpy_flt<cpy_1_flt<float, float>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
return fast_cpy ? (void *)cpy_flt_contiguous<float, nv_bfloat16> : (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
return fast_cpy ? (void *)cpy_flt_contiguous<float, half> : (void*) cpy_flt<cpy_1_flt<float, half>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
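
A sketch (not from this commit) of why the contiguous path is cheaper: the general copy kernel has to reconstruct a 4-D index from the flat thread index and apply independent byte strides for source and destination, while the contiguous same-shape case collapses to one flat load/store (plus an optional convert) per element.

#include <cstdint>

// Roughly the per-element index arithmetic the generic strided path has to do.
static inline const char * strided_ptr(const char * base, int64_t i,
                                       int64_t ne0, int64_t ne1, int64_t ne2,
                                       int64_t nb0, int64_t nb1, int64_t nb2, int64_t nb3) {
    const int64_t i3 =  i / (ne2*ne1*ne0);
    const int64_t i2 = (i - i3*ne2*ne1*ne0) / (ne1*ne0);
    const int64_t i1 = (i - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
    const int64_t i0 =  i - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0;
    return base + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
}

// The contiguous path: the flat index is already the element index in both tensors,
// which is what cpy_flt_contiguous does per thread (plus the F32->F16/BF16 convert).
static void contiguous_copy_ref(const float * src, float * dst, int64_t ne) {
    for (int64_t i = 0; i < ne; ++i) {
        dst[i] = src[i];
    }
}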


@@ -168,9 +168,10 @@ void ggml_cuda_op_mul_mat_vec_q_3D(
GGML_UNUSED(src1_ddf_i);
}
void ggml_cuda_op_mul_mat_vec_q_biased(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_tensor * bias,
const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
@@ -180,14 +181,37 @@ void ggml_cuda_op_mul_mat_vec_q(
const int64_t ne0 = dst->ne[0];
if (bias) {
if (bias->ne[0] != ne0) {
printf("Oops: bias %s is %ld x %ld x %ld x %ld, dst %s is %ld x %ld x %ld x %ld\n",
bias->name, bias->ne[0], bias->ne[1], bias->ne[2], bias->ne[3],
dst->name, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
}
GGML_ASSERT(bias->ne[0] == ne0);
GGML_ASSERT(bias->type == GGML_TYPE_F32);
if (ggml_nrows(bias) != 1) {
printf("Oops: bias %s is %ld x %ld x %ld x %ld\n", bias->name, bias->ne[0], bias->ne[1], bias->ne[2], bias->ne[3]);
}
GGML_ASSERT(ggml_nrows(bias) == 1);
}
ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
ne00, ne0, 1, 0, 0, 0, 0, 0,
src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, nullptr, bias ? bias->data : nullptr, nullptr,
row_low, row_high, src1_ncols,
src1_padded_row_size, GGML_UNARY_OP_COUNT, stream);
GGML_UNUSED(src1_ddf_i);
}
void ggml_cuda_op_mul_mat_vec_q(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
ggml_cuda_op_mul_mat_vec_q_biased(ctx, src0, src1, dst, nullptr, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, row_low, row_high, src1_ncols,
src1_padded_row_size, stream);
}
void ggml_cuda_op_mul_mat_vec_q_id(
ggml_backend_cuda_context & ctx,


@@ -9,12 +9,20 @@
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
void ggml_cuda_op_mul_mat_vec_q_biased(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_tensor * bias,
const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
void ggml_cuda_op_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
bool ggml_cuda_mmvq_type_supported(ggml_type src0_type);
void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,


@@ -1240,14 +1240,17 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
if (bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
ggml_build_forward_expand(gf, Qcur);
}
if (bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
ggml_build_forward_expand(gf, Kcur);
}
if (bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
ggml_build_forward_expand(gf, Vcur);
}
return {Qcur, Kcur, Vcur};
}
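
The added ggml_build_forward_expand calls appear to be what makes the fusion pattern visible to the back end: each projection and its bias add are appended to the graph at the point where they are built, so the MUL_MAT and ADD nodes stay adjacent in cgraph->nodes instead of having other ops scheduled between them. A minimal sketch of the idea, with hypothetical tensors and sizes rather than the actual model code:

#include "ggml.h"

static struct ggml_cgraph * build_example_graph(struct ggml_context * ctx) {
    struct ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4096);        // current activations
    struct ggml_tensor * wq  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 4096);  // Q projection weights
    struct ggml_tensor * bq  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4096);        // Q bias (one row)

    struct ggml_cgraph * gf = ggml_new_graph(ctx);

    struct ggml_tensor * Qcur = ggml_mul_mat(ctx, wq, cur);  // MUL_MAT node
    Qcur = ggml_add(ctx, Qcur, bq);                          // ADD node (bias)
    ggml_build_forward_expand(gf, Qcur);                     // append MUL_MAT + ADD to gf right away

    // ... same for Kcur / Vcur, each expanded immediately after its bias add ...
    return gf;
}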