Make quantized KV cache work

This commit is contained in:
Kawrakow
2026-01-25 05:51:44 +00:00
parent 6a5111c215
commit 4d5dcba7c9

View File

@@ -75,10 +75,10 @@ static void glm45_flash_attention(ggml_backend_cuda_context & ctx, ggml_tensor *
auto V = dst->src[2];
GGML_ASSERT(Q->ne[2] / K->ne[2] == 12);
ggml_cuda_pool_alloc<half> k_data(ctx.pool());
ggml_cuda_pool_alloc<half> v_data(ctx.pool());
ggml_cuda_pool_alloc<float> q_data(ctx.pool(), ggml_nelements(Q));
ggml_cuda_pool_alloc<float> dst_data(ctx.pool(), ggml_nelements(dst));
ggml_cuda_pool_alloc<half> k_data(ctx.pool());
ggml_cuda_pool_alloc<half> v_data(ctx.pool());
repack_q(Q, q_data.get(), 8, 4, K->ne[2], ctx.stream());
@@ -102,6 +102,12 @@ static void glm45_flash_attention(ggml_backend_cuda_context & ctx, ggml_tensor *
to_fp_16(K->data, k_data.get(), 1, nelem, ctx.stream());
local_K.type = GGML_TYPE_F16;
local_K.data = k_data.get();
auto ts = ggml_type_size(K->type);
auto bs = ggml_blck_size(K->type);
local_K.nb[0] = sizeof(half);
local_K.nb[1] = sizeof(half)*bs * local_K.nb[1]/ts;
local_K.nb[2] = sizeof(half)*bs * local_K.nb[2]/ts;
local_K.nb[3] = sizeof(half)*bs * local_K.nb[3]/ts;
}
if (V->type != GGML_TYPE_F16) {
auto nelem = ggml_nelements(V);
@@ -110,6 +116,12 @@ static void glm45_flash_attention(ggml_backend_cuda_context & ctx, ggml_tensor *
to_fp_16(V->data, v_data.get(), 1, nelem, ctx.stream());
local_V.type = GGML_TYPE_F16;
local_V.data = v_data.get();
auto ts = ggml_type_size(V->type);
auto bs = ggml_blck_size(V->type);
local_V.nb[0] = sizeof(half);
local_V.nb[1] = sizeof(half)*bs * local_V.nb[1]/ts;
local_V.nb[2] = sizeof(half)*bs * local_V.nb[2]/ts;
local_V.nb[3] = sizeof(half)*bs * local_V.nb[3]/ts;
}
constexpr int n_op_params = GGML_MAX_OP_PARAMS / sizeof(int);