From 030ba3aebf3f48c477d5f03cd359937e4790fb20 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sat, 5 Jul 2025 19:10:56 +0300
Subject: [PATCH] Trying to implement quantized fmoe - not working yet

---
 ggml/src/ggml-cuda.cu | 65 ++++++++++++++++++++++++++++++++++++++----
 1 file changed, 58 insertions(+), 7 deletions(-)

diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index e0035c7a..2dc48d4d 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2673,7 +2673,25 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
             }
         }
     } else {
-        ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
+        //printf("ne10 = %ld, ne11 = %ld, ne12 = %ld, nb10 = %zu nb11 = %zu nb12 = %zu\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[0], src1->nb[1], src1->nb[2]);
+        ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
+        bool use_quantized_src1 = false;
+        int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0;
+        if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) {
+            src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
+            src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1);
+            src1_quantized_size = src1_padded_row_size*src1->ne[2] + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
+            src1_quantized.alloc(src1_quantized_size);
+            quantize_mmq_q8_1_cuda((const float *)src1->data, src1_quantized.get(), src1->ne[0], src1->ne[2], src1->ne[3], src1_padded_num_cols, src0_1->type, stream);
+            CUDA_CHECK(cudaGetLastError());
+            use_quantized_src1 = true;
+        }
+        ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool());
+        if (use_quantized_src1) {
+            src1_contiguous.alloc(src1_quantized_size);
+        } else {
+            src1_contiguous.alloc(sizeof(float)*ggml_nelements(src1));
+        }
         ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
         ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
         ggml_cuda_pool_alloc<char> final_dst_contiguous(ctx.pool());
@@ -2704,7 +2722,17 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
             if (num_src1_rows == 0) continue;
 
             size_t mapping_offset = cum_moe_counts[i02];
-            {
+            if (use_quantized_src1) {
+                unsigned int eff_ne10 = src1_padded_row_size/sizeof(float);
+                dim3 block_dims(std::min(eff_ne10, 768u));
+                dim3 grid_dims(num_src1_rows);
+                k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        src1_quantized.get(), src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, eff_ne10, ne11, src1_padded_row_size, src1_padded_row_size);
+                CUDA_CHECK(cudaGetLastError());
+                src1_row.nb[0] = sizeof(block_q8_1);
+                src1_row.type = GGML_TYPE_Q8_1;
+            }
+            else {
                 dim3 block_dims(std::min((unsigned int)ne10, 768u));
                 dim3 grid_dims(num_src1_rows);
                 k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
@@ -2719,21 +2747,44 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
             GGML_ASSERT(nb1 == sizeof(float)*ne0);
 
             src1_row.ne[1] = num_src1_rows;
-            src1_row.nb[1] = nb11;
-            src1_row.nb[2] = num_src1_rows*nb11;
-            src1_row.nb[3] = num_src1_rows*nb11;
+            src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11;
+            src1_row.nb[2] = num_src1_rows*src1_row.nb[1];
+            src1_row.nb[3] = num_src1_rows*src1_row.nb[1];
 
             dst_row.ne[1] = num_src1_rows;
             dst_row.nb[1] = nb1;
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
+//struct mmq_args {
+//    const char * x; const char * y; float * dst;
+//    int64_t ne00; int64_t ne01; int64_t stride01;
+//    int64_t ne10; int64_t ne11; int64_t stride11;
+//    int64_t ne0;
+//};
+
+//    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, nb01, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
+
+            //ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst,
+            //        (const char *)src0_1->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_up_contiguous.get(),
+            //        0, src0_1->ne[1], 1, src1_padded_col_size, stream);
+
             dst_row.data = dst_up_contiguous.get();
-            ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
+            if (use_quantized_src1) {
+                ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_contiguous.get(), (float *)dst_row.data,
+                        0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
+            } else {
+                ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
+            }
             CUDA_CHECK(cudaGetLastError());
 
             dst_row.data = dst_gate_contiguous.get();
-            ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
+            if (use_quantized_src1) {
+                ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_contiguous.get(), (float *)dst_row.data,
+                        0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
+            } else {
+                ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
+            }
             CUDA_CHECK(cudaGetLastError());
 
             ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
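
A note on the sizing arithmetic in the first hunk: src1_padded_num_cols rounds the activation row length src1->ne[0] up to a multiple of MATRIX_ROW_PADDING, src1_padded_row_size converts that padded count into bytes of Q8_1 blocks, and the extra get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq) term reserves tail room for the MMQ kernels' tiled Q8_1 layout. The following is a minimal standalone sketch of the same computation, assuming ggml's usual Q8_1 layout (32 values per block, 36 bytes per block: 32 int8 quants plus two fp16 scale fields) and MATRIX_ROW_PADDING = 512; the constants and example dimensions here are illustrative assumptions, not values taken from the patch.

// Standalone sketch of the Q8_1 buffer sizing used in the first hunk above.
// The block-layout constants mirror ggml's Q8_1 type but are restated here as
// assumptions; the example tensor dimensions are hypothetical.
#include <cstdint>
#include <cstdio>

constexpr int64_t kQ8_1BlockValues  = 32;  // values per Q8_1 block (QK8_1)
constexpr int64_t kQ8_1BlockBytes   = 36;  // 32 x int8 quants + 2 x fp16 scales
constexpr int64_t kMatrixRowPadding = 512; // assumed value of MATRIX_ROW_PADDING

// Equivalent of GGML_PAD: round x up to the next multiple of n.
constexpr int64_t pad_up(int64_t x, int64_t n) { return (x + n - 1) / n * n; }

int main() {
    const int64_t ne10     = 4096; // hypothetical row length, src1->ne[0]
    const int64_t n_tokens = 8;    // hypothetical src1->ne[2]

    const int64_t padded_num_cols = pad_up(ne10, kMatrixRowPadding);
    const int64_t padded_row_size = padded_num_cols / kQ8_1BlockValues * kQ8_1BlockBytes;

    // The patch additionally reserves get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq)
    // bytes of tail padding for the MMQ tile loads; that term is device-dependent
    // and omitted here.
    std::printf("padded cols = %lld, bytes per row = %lld, buffer = %lld bytes\n",
                (long long)padded_num_cols, (long long)padded_row_size,
                (long long)(padded_row_size * n_tokens));
    return 0;
}

With these numbers a 4096-wide row stays at 4096 padded columns and quantizes to 4608 bytes per row; that per-row byte count is what the patch reuses as src1_row.nb[1] and as both copy strides for k_copy_src_to_contiguous, which moves each quantized row as eff_ne10 = src1_padded_row_size/sizeof(float) float-sized words.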