diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index f8b86e29..a9b64ff0 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -366,20 +366,22 @@ void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * || GGML_CUDA_CC_IS_CDNA(cc); if (!ids_tensor) { - const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + - get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq); - ggml_cuda_pool_alloc src1_q8_1(ctx.pool(), nbytes_src1_q8_1); - { + ggml_cuda_pool_alloc src1_q8_1(ctx.pool()); + if (!src1_quantized_data) { + const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq); + src1_q8_1.alloc(nbytes_src1_q8_1); quantize_mmq_q8_1_cuda(src1_d, src1_q8_1.get(), ne10, ne11, 1, ne10_padded, src0->type, stream); CUDA_CHECK(cudaGetLastError()); + src1_quantized_data = src1_q8_1.get(); } const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); const int64_t s13 = ne12*s12; const mmq_args_id args = { - src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d, + src0_d, src0->type, (const int *)src1_quantized_data, nullptr, nullptr, dst_d, ne00, ne01, ne1, s01, ne11, s1, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3,