Fuse Q, K, V gemv+add

This commit is contained in:
Iwan Kawrakow
2025-10-25 17:30:45 +03:00
parent f76e98536f
commit 5ddde01542
5 changed files with 111 additions and 7 deletions

View File

@@ -2064,8 +2064,26 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const ggml_cgraph * cgraph, int node_n, bool is_gemv) {
//if (cgraph && node_n + 6 < cgraph->n_nodes) {
// printf("=== %s\n", __func__);
// for (int i = 0; i <= 6; ++i) printf("%d: %s(%s)\n", i, ggml_op_name(cgraph->nodes[node_n+i]->op), cgraph->nodes[node_n+i]->name);
//}
auto stream = ctx.stream();
//if (cgraph && node_n + 5 < cgraph->n_nodes &&
// cgraph->nodes[node_n+1]->op == GGML_OP_MUL_MAT &&
// cgraph->nodes[node_n+2]->op == GGML_OP_MUL_MAT &&
// cgraph->nodes[node_n+3]->op == GGML_OP_ADD &&
// cgraph->nodes[node_n+4]->op == GGML_OP_ADD &&
// cgraph->nodes[node_n+5]->op == GGML_OP_ADD &&
// cgraph->nodes[node_n+0] == cgraph->nodes[node_n+3]->src[0] &&
// cgraph->nodes[node_n+1] == cgraph->nodes[node_n+4]->src[0] &&
// cgraph->nodes[node_n+2] == cgraph->nodes[node_n+5]->src[0]) {
// printf("Could process mulmat(%s) + mulmat(%s) + mulmat(%s) + add(%s) + add(%s) + add(%s)\n",
// cgraph->nodes[node_n+0]->name, cgraph->nodes[node_n+1]->name, cgraph->nodes[node_n+2]->name,
// cgraph->nodes[node_n+3]->name, cgraph->nodes[node_n+4]->name, cgraph->nodes[node_n+5]->name);
//}
auto ne10_padded = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
auto nb10_padded = ne10_padded*sizeof(block_q8_1)/QK8_1;
auto quantized_size = nb10_padded*ggml_nrows(src1);
@@ -2078,9 +2096,35 @@ static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tenso
src0->type, stream);
CUDA_CHECK(cudaGetLastError());
ggml_cuda_op_mul_mat_vec_q(ctx, src0, src1, dst, (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
0, src0->ne[1], src1->ne[1], ne10_padded, stream);
CUDA_CHECK(cudaGetLastError());
if (cgraph && node_n + 5 < cgraph->n_nodes &&
cgraph->nodes[node_n+1]->op == GGML_OP_MUL_MAT &&
cgraph->nodes[node_n+2]->op == GGML_OP_MUL_MAT &&
ggml_is_quantized(cgraph->nodes[node_n+1]->src[0]->type) &&
ggml_is_quantized(cgraph->nodes[node_n+2]->src[0]->type) &&
cgraph->nodes[node_n+3]->op == GGML_OP_ADD &&
cgraph->nodes[node_n+4]->op == GGML_OP_ADD &&
cgraph->nodes[node_n+5]->op == GGML_OP_ADD &&
cgraph->nodes[node_n+0] == cgraph->nodes[node_n+3]->src[0] &&
cgraph->nodes[node_n+1] == cgraph->nodes[node_n+4]->src[0] &&
cgraph->nodes[node_n+2] == cgraph->nodes[node_n+5]->src[0]) {
//printf("Processing mulmat(%s) + mulmat(%s) + mulmat(%s) + add(%s) + add(%s) + add(%s)\n",
// cgraph->nodes[node_n+0]->name, cgraph->nodes[node_n+1]->name, cgraph->nodes[node_n+2]->name,
// cgraph->nodes[node_n+3]->name, cgraph->nodes[node_n+4]->name, cgraph->nodes[node_n+5]->name);
for (int i = 0; i < 3; ++i) {
auto src0_i = cgraph->nodes[node_n+i]->src[0];
//printf(" using %s(%s) with %s, %s\n", ggml_op_name(cgraph->nodes[node_n+i+3]->op), cgraph->nodes[node_n+i+3]->name,
// cgraph->nodes[node_n+i+3]->src[0]->name, cgraph->nodes[node_n+i+3]->src[1]->name);
ggml_cuda_op_mul_mat_vec_q_biased(ctx, src0_i, src1, cgraph->nodes[node_n+i], cgraph->nodes[node_n+i+3]->src[1],
(const char *)src0_i->data, nullptr, src1_quantized.get(), (float *)cgraph->nodes[node_n+i]->data,
0, src0_i->ne[1], src1->ne[1], ne10_padded, stream);
CUDA_CHECK(cudaGetLastError());
}
node_n += 5;
} else {
ggml_cuda_op_mul_mat_vec_q(ctx, src0, src1, dst, (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
0, src0->ne[1], src1->ne[1], ne10_padded, stream);
CUDA_CHECK(cudaGetLastError());
}
} else {
quantize_mmq_q8_1_cuda((const float *)src1->data, src1_quantized.get(), src1->ne[0], src1->ne[1], 1, ne10_padded, src0->type, stream);
CUDA_CHECK(cudaGetLastError());

View File

@@ -313,7 +313,25 @@ void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(&aux_dst, &aux_src, &aux_dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
}
static __global__ void k_fast_add(int64_t ne0, int64_t nelem, const float * x, const float * y, float * z) {
int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= nelem) {
return;
}
z[i] = x[i] + y[i % ne0];
}
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (ggml_nrows(dst->src[1]) == 1 && dst->src[0]->ne[0] == dst->src[1]->ne[0] &&
dst->type == GGML_TYPE_F32 && dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 &&
ggml_are_same_shape(dst, dst->src[0]) && ggml_is_contiguous(dst)) {
constexpr int kBlockSize = 256;
auto nelem = ggml_nelements(dst);
int nblocks = (nelem + kBlockSize - 1)/kBlockSize;
k_fast_add<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
(const float *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data);
return;
}
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

View File

@@ -167,10 +167,18 @@ void ggml_cuda_op_mul_mat_vec_q_3D(
GGML_UNUSED(src1_ddf_i);
}
//static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type,
// const int64_t ne00, const int64_t ne0, const int64_t ne2,
// const int64_t nb02, const int64_t nb12, const int64_t nb2, const int64_t ids_nb0, const int64_t bias_nb1,
// const char * src0_dd_u, const char * src0_dd_g, const char * src1_ddq_i, float * dst_dd_i, const char * ids_data,
// const void * bias_u, const void * bias_g,
// const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
// const int64_t src1_padded_row_size, ggml_unary_op unary_op, cudaStream_t stream) {
void ggml_cuda_op_mul_mat_vec_q(
void ggml_cuda_op_mul_mat_vec_q_biased(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_tensor * bias,
const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
@@ -180,14 +188,37 @@ void ggml_cuda_op_mul_mat_vec_q(
const int64_t ne0 = dst->ne[0];
if (bias) {
if (bias->ne[0] != ne0) {
printf("Oops: bias %s is %ld x %ld x %ld x %ld, dst %s is %ld x %ld x %ld x %ld\n",
bias->name, bias->ne[0], bias->ne[1], bias->ne[2], bias->ne[3],
dst->name, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
}
GGML_ASSERT(bias->ne[0] == ne0);
GGML_ASSERT(bias->type == GGML_TYPE_F32);
if (ggml_nrows(bias) != 1) {
printf("Oops: bias %s is %ld x %ld x %ld x %ld\n", bias->name, bias->ne[0], bias->ne[1], bias->ne[2], bias->ne[3]);
}
GGML_ASSERT(ggml_nrows(bias) == 1);
}
ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
ne00, ne0, 1, 0, 0, 0, 0, 0,
src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, nullptr, nullptr, nullptr,
src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, nullptr, bias ? bias->data : nullptr, nullptr,
row_low, row_high, src1_ncols,
src1_padded_row_size, GGML_UNARY_OP_COUNT, stream);
GGML_UNUSED(src1_ddf_i);
}
void ggml_cuda_op_mul_mat_vec_q(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
ggml_cuda_op_mul_mat_vec_q_biased(ctx, src0, src1, dst, nullptr, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, row_low, row_high, src1_ncols,
src1_padded_row_size, stream);
}
void ggml_cuda_op_mul_mat_vec_q_id(
ggml_backend_cuda_context & ctx,

View File

@@ -9,12 +9,20 @@
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
void ggml_cuda_op_mul_mat_vec_q_biased(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_tensor * bias,
const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
void ggml_cuda_op_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
bool ggml_cuda_mmvq_type_supported(ggml_type src0_type);
void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,

View File

@@ -1240,14 +1240,17 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
if (bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
ggml_build_forward_expand(gf, Qcur);
}
if (bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
ggml_build_forward_expand(gf, Kcur);
}
if (bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
ggml_build_forward_expand(gf, Vcur);
}
return {Qcur, Kcur, Vcur};
}