Fused matrix multiplications (CUDA and CPU) (#796)

* Quick attempt to fuse the Q, K, V GEMMs (see the sketch after the commit metadata below)

Doesn't do much on the CPU

* Doesn't do much on the GPU either

* Use llm_build_mul_mat_qkv

* This is not needed

* Revert timing change committed by mistake

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
commit f8b66238fa
parent 9c6988f61c
Author: Kawrakow (committed by GitHub)
Date: 2025-09-24 16:52:54 +02:00
3 changed files with 244 additions and 564 deletions
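
In outline, the change teaches both back ends to recognise runs of consecutive MUL_MAT nodes that consume the same activation tensor (as the Q, K and V projections do), quantize that tensor once, run every matmul in the run against the cached quantized data, and return the index of the last node consumed so the graph loop can skip ahead. The standalone C++ sketch below illustrates only that control flow; Node, Op, quantize_input and run_matmul are invented for the illustration and are not part of the ggml API.

    // Minimal standalone sketch of the fusion idea (hypothetical types, not ggml).
    #include <cstdio>
    #include <vector>

    enum class Op { MulMat, Other };

    struct Node {
        Op  op;
        int weights_id;   // stands in for src0 (the weight tensor)
        int input_id;     // stands in for src1 (the shared activation tensor)
    };

    // Hypothetical stand-ins for "quantize src1 once" and "one GEMM against the cached data".
    static void quantize_input(int input_id) { std::printf("quantize input %d once\n", input_id); }
    static void run_matmul(const Node & n)   { std::printf("  matmul with weights %d\n", n.weights_id); }

    // Runs the mat-mul at graph[node_n], then fuses any directly following mat-muls that
    // share the same input. Returns the index of the last node consumed, mirroring the
    // "return node_n" convention this commit introduces.
    static int run_fused_matmuls(const std::vector<Node> & graph, int node_n) {
        const Node & first = graph[node_n];
        quantize_input(first.input_id);
        run_matmul(first);
        while (node_n + 1 < (int)graph.size() &&
               graph[node_n + 1].op == Op::MulMat &&
               graph[node_n + 1].input_id == first.input_id) {
            run_matmul(graph[node_n + 1]);
            ++node_n;
        }
        return node_n;
    }

    int main() {
        // The Q, K and V projections all consume the same hidden state (input 0).
        std::vector<Node> graph = {
            {Op::MulMat, 0, 0}, {Op::MulMat, 1, 0}, {Op::MulMat, 2, 0}, {Op::Other, -1, -1},
        };
        for (int i = 0; i < (int)graph.size(); ++i) {
            if (graph[i].op == Op::MulMat) i = run_fused_matmuls(graph, i);
        }
        return 0;
    }

In the diffs that follow, the CUDA path realises this with quantize_row_q8_1_cuda / quantize_mmq_q8_1_cuda feeding ggml_cuda_op_mul_mat_vec_q / ggml_cuda_op_mul_mat_q, and the CPU path does the equivalent around iqk_mul_mat_4d.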


@@ -2143,7 +2143,62 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
}
}
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
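// Fused quantized mat-mul for the non-split case with src1->ne[2]*src1->ne[3] == 1:
// quantize src1 to Q8_1 once, run MMVQ (is_gemv) or MMQ for this node, then walk ahead
// in the graph (skipping no-op nodes) and reuse the same quantized src1 for directly
// following MUL_MAT nodes that share it, e.g. the Q, K and V projections.
// Returns the index of the last graph node that was computed.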
static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const ggml_cgraph * cgraph, int node_n, bool is_gemv) {
auto stream = ctx.stream();
auto ne10_padded = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
auto nb10_padded = ne10_padded*sizeof(block_q8_1)/QK8_1;
auto quantized_size = nb10_padded*ggml_nrows(src1);
if (!is_gemv) {
quantized_size += get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
}
ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool(), quantized_size);
if (is_gemv) {
quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], src1->ne[1], src1->ne[2], ne10_padded,
src0->type, stream);
CUDA_CHECK(cudaGetLastError());
ggml_cuda_op_mul_mat_vec_q(ctx, src0, src1, dst, (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
0, src0->ne[1], src1->ne[1], ne10_padded, stream);
CUDA_CHECK(cudaGetLastError());
} else {
quantize_mmq_q8_1_cuda((const float *)src1->data, src1_quantized.get(), src1->ne[0], src1->ne[1], 1, ne10_padded, src0->type, stream);
CUDA_CHECK(cudaGetLastError());
ggml_cuda_op_mul_mat_q(ctx, src0, src1, dst, (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
0, src0->ne[1], src1->ne[1], ne10_padded, stream);
CUDA_CHECK(cudaGetLastError());
}
if (!cgraph) return node_n;
while (node_n + 1 < cgraph->n_nodes) {
dst = cgraph->nodes[node_n+1];
if (ggml_is_empty(dst) || dst->op == GGML_OP_RESHAPE || dst->op == GGML_OP_TRANSPOSE || dst->op == GGML_OP_VIEW
|| dst->op == GGML_OP_PERMUTE || dst->op == GGML_OP_NONE) {
++node_n; continue;
}
if (dst->op != GGML_OP_MUL_MAT || dst->src[1] != src1 || !ggml_is_quantized(dst->src[0]->type)) break;
if (!is_gemv && mmq_get_q8_1_ds_layout(src0->type) != mmq_get_q8_1_ds_layout(dst->src[0]->type)) break;
if (is_gemv) {
ggml_cuda_op_mul_mat_vec_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
(float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
} else {
ggml_cuda_op_mul_mat_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
(float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
}
CUDA_CHECK(cudaGetLastError());
++node_n;
}
return node_n;
}
static int ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const ggml_cgraph * cgraph, int node_n) {
const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
// If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q.
@@ -2188,6 +2243,10 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
}
if (!split && (use_mul_mat_vec_q || use_mul_mat_q) && src1->ne[2]*src1->ne[3] == 1) {
return ggml_cuda_mul_mat_q(ctx, src0, src1, dst, cgraph, node_n, use_mul_mat_vec_q);
}
// debug helpers
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
@@ -2215,6 +2274,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
} else {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
}
return node_n;
}
struct mmid_row_mapping {
@@ -2454,7 +2514,7 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
src1_row.data = src1_original + i11*nb11 + i12*nb12;
dst_row.data = dst_original + i1*nb1 + i2*nb2;
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row, nullptr, 0);
}
}
} else {
@@ -2505,7 +2565,7 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
dst_row.nb[2] = num_src1_rows*nb1;
dst_row.nb[3] = num_src1_rows*nb1;
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row, nullptr, 0);
{
dim3 block_dims(std::min((unsigned int)ne0, 768u));
@@ -2889,7 +2949,7 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
} else {
ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row, nullptr, 0);
}
CUDA_CHECK(cudaGetLastError());
@@ -2906,7 +2966,7 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
} else {
ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row, nullptr, 0);
}
CUDA_CHECK(cudaGetLastError());
@@ -2947,8 +3007,7 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
(int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]);
first = false;
}
ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst);
//ggml_cuda_mul_mat(ctx, next->src[0], &dst_row, &final_dst);
ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst, nullptr, 0);
CUDA_CHECK(cudaGetLastError());
dim3 block_dims(std::min((unsigned int)next->ne[0], 768u));
@@ -3031,8 +3090,7 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
}
static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next,
const ggml_cgraph * cgraph, int & i) {
static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, const ggml_cgraph * cgraph, int & i) {
// why is this here instead of mul_mat?
if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
@@ -3042,6 +3100,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
int64_t tim1 = ggml_time_us();
#endif
auto next = i < cgraph->n_nodes - 1 ? cgraph->nodes[i+1] : nullptr;
switch (dst->op) {
case GGML_OP_REPEAT:
ggml_cuda_op_repeat(ctx, dst);
@@ -3112,7 +3172,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_hardswish(ctx, dst);
break;
default:
return false;
return -1;
}
break;
case GGML_OP_NORM:
@@ -3148,9 +3208,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_MUL_MAT:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
return false;
return -1;
} else {
ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
i = ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst, cgraph, i);
}
break;
case GGML_OP_MUL_MAT_ID:
@@ -3569,7 +3629,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
ggml_tensor * next = i < cgraph->n_nodes-1 ? cgraph->nodes[i+1] : nullptr;
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
@@ -3604,7 +3663,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
GGML_UNUSED(integrated);
#endif // NDEBUG
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, cgraph, i);
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, cgraph, i);
if (!ok) {
GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
}
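
A rough sense of scale for the quantization work the fused CUDA path above avoids: the Q8_1 scratch sized from ne10_padded (nb10_padded bytes per row) is now filled once per shared src1 instead of once per projection. The standalone sketch below just does that arithmetic; QK8_1 = 32, a 36-byte block_q8_1 and MATRIX_ROW_PADDING = 512 are assumed values typical of ggml builds, not taken from this diff.

    // Back-of-the-envelope, standalone: bytes of Q8_1 data written by the quantization
    // step when the Q, K and V matmuls are fused (src1 quantized once) versus unfused
    // (quantized three times). Constants below are assumptions, as noted above.
    #include <cstdint>
    #include <cstdio>

    constexpr int64_t QK8_1              = 32;        // quants per q8_1 block (assumed)
    constexpr int64_t BLOCK_Q8_1_BYTES   = 4 + QK8_1; // half d, half s + 32 int8 quants (assumed)
    constexpr int64_t MATRIX_ROW_PADDING = 512;       // row padding on the CUDA path (assumed)

    static int64_t pad_to(int64_t x, int64_t p) { return (x + p - 1) / p * p; }

    int main() {
        const int64_t ne10 = 4096;  // example hidden size
        const int64_t ne11 = 1;     // one token, i.e. the GEMV case
        const int64_t ne10_padded   = pad_to(ne10, MATRIX_ROW_PADDING);
        const int64_t bytes_per_row = ne10_padded / QK8_1 * BLOCK_Q8_1_BYTES;  // nb10_padded in the diff
        const int64_t fused         = bytes_per_row * ne11;  // quantized once, reused for Q, K, V
        const int64_t unfused       = 3 * fused;             // quantized separately per projection
        std::printf("fused Q/K/V quantization: %lld bytes, unfused: %lld bytes\n",
                    (long long)fused, (long long)unfused);
        return 0;
    }

For a 4096-wide hidden state and a single token that is about 4.5 KiB written once instead of three times; the saving amounts to a couple of quantization kernel launches and a little memory traffic per layer, which is consistent with the commit message's note that the fusion by itself does not change performance much.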


@@ -14974,9 +14974,11 @@ static inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
return a;
}
static void ggml_compute_forward_mul_mat(
static int ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
struct ggml_tensor * dst,
const struct ggml_cgraph * cgraph,
int node_n) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
@@ -15017,12 +15019,6 @@ static void ggml_compute_forward_mul_mat(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
#if GGML_USE_LLAMAFILE
// broadcast factors
const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;
#endif
#if GGML_USE_IQK_MULMAT
if (ith == 0) {
static bool first_time = true;
@@ -15040,34 +15036,10 @@ static void ggml_compute_forward_mul_mat(
ne02, ne03, ne12, ne13, nb02, nb03, nb12, nb13, nb2/sizeof(float), nb3/sizeof(float),
src0->type, src0->data, nb01,
src1->type, src1->data, nb11,
(float *)dst->data, nb1/sizeof(float), ith, nth)) return;
(float *)dst->data, nb1/sizeof(float), ith, nth)) return node_n;
}
#endif
#if GGML_USE_LLAMAFILE
const bool src1_cont = ggml_is_contiguous(src1);
if (src1_cont) {
for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++)
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(src0->type),
(const char *)src1->data + i12*nb12 + i13*nb13,
nb11/ggml_type_size(src1->type),
(char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type),
ith, nth,
src0->type,
src1->type,
dst->type))
goto UseGgmlGemm1;
return;
}
UseGgmlGemm1:;
#endif
if (src1->type != vec_dot_type) {
char * wdata = params->wdata;
@@ -15092,51 +15064,27 @@ UseGgmlGemm1:;
}
else {
//#ifdef GGML_USE_IQK_MULMAT
// int ts = type_traits[vec_dot_type].type_size;
// int bs = type_traits[vec_dot_type].blck_size;
// int64_t blocks_per_row = ne10/bs;
// int64_t num_blocks = ne11*ne12*ne13*blocks_per_row;
// int gcd = simple_gcd(128, ts); // 128 is to cover cache line sizes for common architectures without getting involved
// // with trying to get it from ggml
// int64_t num_blocks_gcd = (num_blocks + gcd - 1)/gcd;
// int64_t block_per_thread = ((num_blocks_gcd + nth - 1)/nth)*gcd;
// int64_t first_block = ith*block_per_thread;
// int64_t last_block = MIN(num_blocks, first_block + block_per_thread);
// while (first_block < last_block) {
// int64_t i13 = first_block/(ne11*ne12*blocks_per_row);
// int64_t i12 = (first_block - i13*ne11*ne12*blocks_per_row)/(ne11*blocks_per_row);
// int64_t i11 = (first_block - (i13*ne12 + i12)*ne11*blocks_per_row)/blocks_per_row;
// int64_t i10 = first_block % blocks_per_row;
// int64_t blocks_to_do = MIN(blocks_per_row - i10, last_block - first_block);
// from_float((float *)((char *)src1->data + i13*nb13 + i12*nb12 + i11*nb11) + i10*bs,
// (void *)(wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + i10*ts), blocks_to_do*bs);
// first_block += blocks_to_do;
// }
//#else
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
int64_t i11_processed = 0;
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
int64_t i11_processed = 0;
#if !GGML_USE_IQK_MULMAT
if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
4, ne10, blck_size_interleave);
if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
4, ne10, blck_size_interleave);
}
i11_processed = ne11 - ne11 % 4;
}
i11_processed = ne11 - ne11 % 4;
}
#endif
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
ne10);
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
ne10);
}
}
}
}
//#endif
}
ggml_barrier(params->shared);
@@ -15145,17 +15093,10 @@ UseGgmlGemm1:;
if (ith == 0) printf("quantize(%s): %d us\n", dst->name, (int)(t2 - t1));
#endif
if (ith == 0) {
// Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
//atomic_store(&params->shared->current_chunk, nth);
}
ggml_barrier(params->shared);
}
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
#if GGML_USE_IQK_MULMAT
if (src1->type != vec_dot_type && dst->type == GGML_TYPE_F32) {
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
if (iqk_mul_mat_4d(ne01, ne11, ne00,
@@ -15163,32 +15104,27 @@ UseGgmlGemm1:;
nb2/sizeof(float), nb3/sizeof(float),
src0->type, src0->data, nb01,
vec_dot_type, wdata, row_size,
(float *)dst->data, nb1/sizeof(float), ith, nth)) return;
(float *)dst->data, nb1/sizeof(float), ith, nth)) {
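// iqk_mul_mat_4d handled this node. Keep fusing the immediately following MUL_MAT
// nodes that share the same src1 and the same vec_dot_type, so the quantized
// activations already prepared in wdata are reused instead of being recomputed.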
while (node_n < cgraph->n_nodes - 1 &&
cgraph->nodes[node_n+1]->op == GGML_OP_MUL_MAT &&
cgraph->nodes[node_n+1]->src[1] == src1 &&
type_traits[cgraph->nodes[node_n+1]->src[0]->type].vec_dot_type == vec_dot_type) {
struct ggml_tensor * dst_next = cgraph->nodes[node_n+1];
struct ggml_tensor * src0_next = dst_next->src[0];
GGML_ASSERT(dst_next->type == GGML_TYPE_F32);
GGML_ASSERT(src0_next->ne[0] == ne00);
//if (ith == 0) printf("Fusing %s\n", src0_next->name);
if (!iqk_mul_mat_4d(src0_next->ne[1], ne11, ne00,
src0_next->ne[2], src0_next->ne[3], ne12, ne13, src0_next->nb[2], src0_next->nb[3], row_size*ne11, row_size*ne11*ne12,
dst_next->nb[2]/sizeof(float), dst_next->nb[3]/sizeof(float),
src0_next->type, src0_next->data, src0_next->nb[1],
vec_dot_type, wdata, row_size,
(float *)dst_next->data, dst_next->nb[1]/sizeof(float), ith, nth)) break;
++node_n;
}
}
return node_n;
}
#endif
#if GGML_USE_LLAMAFILE
if (src1->type != vec_dot_type) {
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++)
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(src0->type),
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
row_size/ggml_type_size(vec_dot_type),
(char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type),
ith, nth,
src0->type,
vec_dot_type,
dst->type))
goto UseGgmlGemm2;
return;
}
UseGgmlGemm2:;
#endif
if (ith == 0) {
atomic_store(&params->shared->current_chunk, nth);
@@ -15243,7 +15179,7 @@ UseGgmlGemm2:;
int64_t src0_end = ((ith + 1) * ne01) / nth;
src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
if (src0_start >= src0_end) return;
if (src0_start >= src0_end) return node_n;
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
if (gemm && (ne11 > 3)) {
@@ -15255,7 +15191,7 @@ UseGgmlGemm2:;
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
src0_end - src0_start);
}
return;
return node_n;
}
// The first chunk comes from our thread_id, the rest will get auto-assigned.
@@ -15279,6 +15215,8 @@ UseGgmlGemm2:;
current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
}
return node_n;
}
// ggml_compute_forward_mul_mat_id
@@ -20392,7 +20330,7 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
GGML_ASSERT(params);
if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
return false;
return i;
}
#if IK_PRINT_TIMING
@@ -20506,7 +20444,7 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
} break;
case GGML_OP_MUL_MAT:
{
ggml_compute_forward_mul_mat(params, tensor);
i = ggml_compute_forward_mul_mat(params, tensor, cgraph, i);
} break;
case GGML_OP_MUL_MAT_ID:
{