Make Q8_0 KV cache work with mla=2,fa on CUDA (#264)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-03-18 15:40:47 +01:00
committed by GitHub
parent f4ebf13b6a
commit 68a5b60408
5 changed files with 117 additions and 46 deletions

View File

@@ -3395,6 +3395,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (op->op == GGML_OP_MOE_FUSED_UP_GATE && a->type != op->src[1]->type) {
return false;
}
//==================================================================
//if (ggml_is_quantized(a->type) && ggml_is_quantized(b->type)) {
// return false;
//}
//==================================================================
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16 && !ggml_is_quantized(a->type)) {
return false;
}
@@ -3496,6 +3501,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
return true;
}
if (ggml_is_quantized(src0_type) && (src1_type == GGML_TYPE_F16 || src1_type == GGML_TYPE_F32)) {
return true;
}
if (ggml_is_contiguous(op->src[0]) && ggml_are_same_shape(op->src[0], op->src[1])) {
if (src1_type == GGML_TYPE_F16 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F32) {
return true;

View File

@@ -282,7 +282,28 @@ static void ggml_cuda_op_bin_bcast(
}
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
GGML_ASSERT(dst->type == dst->src[0]->type);
if (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
return;
}
auto src = dst->src[0];
auto bs = ggml_blck_size(src->type);
auto ts = ggml_type_size(src->type);
if (src->nb[0] != ts || ts*(src->ne[0]/bs) % 2 != 0) {
fprintf(stderr, "%s: unsupported case type = %s, nb[0] = %zu, type_size = %zu\n", __func__, ggml_type_name(src->type), src->nb[0], ts);
GGML_ABORT("fatal error");
}
auto aux_src = *src;
aux_src.type = GGML_TYPE_F16;
aux_src.ne[0] = ts*(src->ne[0]/bs)/2;
aux_src.nb[0] = 2;
auto aux_dst = *dst;
aux_dst.type = GGML_TYPE_F16;
aux_dst.ne[0] = ts*(dst->ne[0]/bs)/2;
aux_dst.nb[0] = 2;
aux_dst.src[0] = &aux_src;
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(&aux_dst, &aux_src, &aux_dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View File

@@ -209,6 +209,10 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (dim == 0 && src0->nb[0] == ggml_type_size(src0->type) && src1->nb[0] == ggml_type_size(src1->type) &&
src0->nb[1] % sizeof(float) == 0 && src1->nb[1] % sizeof(float) == 0) {
auto bs = ggml_blck_size(dst->type);
auto ts = ggml_type_size(dst->type);
auto ne00_eff = (src0->ne[0]/bs)*ts/sizeof(float);
auto ne0_eff = (dst->ne[0]/bs)*ts/sizeof(float);
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
//if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) {
// fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name);
@@ -217,25 +221,35 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
//printf("%s(%s, %s): %ld %zu %zu %ld %zu %zu\n", __func__, src0->name, src1->name, src0->ne[0], src0->nb[0], src0->nb[1], dst->ne[0], dst->nb[0], dst->nb[1]);
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
concat_f32_cuda(
src0_d + i3 * (src0->nb[3] / 4),
src1_d + i3 * (src1->nb[3] / 4),
dst_d + i3 * ( dst->nb[3] / 4),
src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2],
dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream);
ne00_eff, src0->ne[1], src0->ne[2],
ne0_eff, dst->ne[1], dst->ne[2], dim, stream);
//src0->nb[1]/sizeof(float), src0->ne[1], src0->ne[2],
//dst->nb[1]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream);
//src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2],
//dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream);
}
} else {
//printf("%s(not contiguous): %s(%s) and %s(%s)\n", __func__, src0->name, ggml_type_name(src0->type), src1->name, ggml_type_name(src1->type));
auto ne10_eff = (src1->ne[0]/bs)*ts/sizeof(float);
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
(const char *)src0->data,
(const char *)src1->data,
( char *)dst->data,
src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], src0->ne[3],
ne00_eff, src0->ne[1], src0->ne[2], src0->ne[3],
//src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], src0->ne[3],
sizeof(float), src0->nb[1], src0->nb[2], src0->nb[3],
src1->ne[0]*src1->nb[0]/sizeof(float), src1->ne[1], src1->ne[2], src1->ne[3],
ne10_eff, src1->ne[1], src1->ne[2], src1->ne[3],
//src1->ne[0]*src1->nb[0]/sizeof(float), src1->ne[1], src1->ne[2], src1->ne[3],
sizeof(float), src1->nb[1], src1->nb[2], src1->nb[3],
dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dst->ne[3],
ne0_eff, dst->ne[1], dst->ne[2], dst->ne[3],
//dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dst->ne[3],
sizeof(float), dst->nb[1], dst->nb[2], dst->nb[3], dim);
}
return;

View File

@@ -66,25 +66,30 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
cpy_1(cx + x_offset, cdst + dst_offset);
}
//static __global__ void cpy_q8_0_f32(const char * cx, float * dst, const int ne,
// const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) {
// const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
//
// if (i >= ne) {
// return;
// }
//
// const int64_t i03 = i/(ne00 * ne01 * ne02);
// const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
// const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
// const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
//
// const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
// const int ib = i00/QK8_0;
// const int iq = i00%QK8_0;
//
// dst[i00*ne01 + i01 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq];
//}
template <typename dst_t>
static __global__ void k_cpy_q8_0_to_float(const char * cx, dst_t * dst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
const int64_t i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
const int ib = i00/QK8_0;
const int iq = i00%QK8_0;
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
dst[i00 + i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __float2bfloat16(__half2float(q8[ib].d)*q8[ib].qs[iq]);
} else {
dst[i00 + i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq];
}
}
static __global__ void k_transpose_q8_0(const char * cx, char * cdst,
const int ne10, const int ne11, const int ne12,
@@ -532,23 +537,26 @@ static void transpose_q8_0(ggml_backend_cuda_context & ctx, const ggml_tensor *
(const char *)src->data, (char *)dst->data,
dst->ne[0], dst->ne[1], dst->ne[2], src->nb[0], src->nb[2], src->nb[3],
dst->nb[1], dst->nb[2], dst->nb[3]);
}
//auto ne = ggml_nelements(dst);
//ggml_cuda_pool_alloc<float> dst_f32(ctx.pool(), ne);
//const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
//auto aux_src = *dst;
//aux_src.nb[0] = sizeof(float);
//aux_src.nb[1] = aux_src.nb[0]*aux_src.ne[0];
//aux_src.nb[2] = aux_src.nb[1]*aux_src.ne[1];
//aux_src.nb[3] = aux_src.nb[2]*aux_src.ne[2];
//cpy_q8_0_f32<<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
// ((const char *)src->data, dst_f32.get(), ne,
// src->ne[1], src->ne[0], src->ne[2], src->nb[0], src->nb[2], src->nb[3]);
//CUDA_CHECK(cudaGetLastError());
//aux_src.type = GGML_TYPE_F32;
//ggml_cpy_f32_q8_0_cuda((const char *)dst_f32.get(), (char *)dst->data, ne, dst->ne[0], dst->ne[1], dst->ne[2],
// aux_src.nb[0], aux_src.nb[1], aux_src.nb[2], aux_src.nb[3],
// dst->ne[0], dst->ne[1], dst->ne[2], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], stream);
static void copy_q8_0_to_float(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
auto stream = ctx.stream();
auto num_blocks = ggml_nelements(dst)/QK8_0;
if (dst->type == GGML_TYPE_F16) {
k_cpy_q8_0_to_float<<<num_blocks, QK8_0, 0, stream>>>((const char *)src->data, (half *)dst->data, ggml_nelements(dst),
src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]);
}
else if (dst->type == GGML_TYPE_F32) {
k_cpy_q8_0_to_float<<<num_blocks, QK8_0, 0, stream>>>((const char *)src->data, (float *)dst->data, ggml_nelements(dst),
src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]);
}
else if (dst->type == GGML_TYPE_BF16) {
k_cpy_q8_0_to_float<<<num_blocks, QK8_0, 0, stream>>>((const char *)src->data, (nv_bfloat16 *)dst->data, ggml_nelements(dst),
src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]);
}
else {
GGML_ABORT("fatal error");
}
}
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
@@ -607,8 +615,13 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 &&
(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)) {
copy_q8_0_to_float(ctx, src0, src1);
} else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) {
if (src1->type == GGML_TYPE_F16) {
auto to_fp16 = ggml_get_to_fp16_cuda(src0->type);
@@ -670,6 +683,9 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_f32_f16<cpy_1_f16_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
} else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 &&
(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)) {
return (void*)copy_q8_0_to_float;
} else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) {
if (src1->type == GGML_TYPE_F16) {
auto to_fp16 = ggml_get_to_fp16_cuda(src0->type);

View File

@@ -13771,9 +13771,21 @@ struct llm_build_context {
auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
// There is still an issue with one or more of the ops GGML_OP_REPEAT, GGML_OP_CONCAT, GGML_OP_CPY on CUDA when
// the KV cache is quantized. Hence, in that case we will simply use fp16 for now.
// The downside of the following line is that fp16 will be used even if attention is computed on the CPU
// if the build is with CUDA enabled.
auto kv_type = lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu ? kv_self.kv_l[il]->type : GGML_TYPE_F16;
ggml_tensor repeater;
repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1;
auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
ggml_tensor * k_rope;
if (kv_cache_rope->type == kv_type) {
k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
} else {
auto kv_cache_rope_f16 = ggml_cast(ctx0, kv_cache_rope, GGML_TYPE_F16);
k_rope = ggml_repeat(ctx0, kv_cache_rope_f16, &repeater);
}
cb(k_rope, "k_rope", il);
auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
@@ -13796,15 +13808,15 @@ struct llm_build_context {
ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
cb(v_f32, "v_f32", il);
auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
cb(v, "v", il);
auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_max_head,
ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
cb(k_nope_f32, "k_nope_f32", il);
auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_self.kv_l[il]->type);
auto v = ggml_cast(ctx0, v_f32, kv_type);
cb(v, "v", il);
auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_type);
cb(k_nope, "k_nope", il);
ggml_build_forward_expand(gf, k_nope);