Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (last synced 2026-04-29 19:01:47 +00:00).
Make quantized KV cache work
This commit is contained in:
@@ -75,10 +75,10 @@ static void glm45_flash_attention(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||||||
auto V = dst->src[2];
|
auto V = dst->src[2];
|
||||||
GGML_ASSERT(Q->ne[2] / K->ne[2] == 12);
|
GGML_ASSERT(Q->ne[2] / K->ne[2] == 12);
|
||||||
|
|
||||||
ggml_cuda_pool_alloc<half> k_data(ctx.pool());
|
|
||||||
ggml_cuda_pool_alloc<half> v_data(ctx.pool());
|
|
||||||
ggml_cuda_pool_alloc<float> q_data(ctx.pool(), ggml_nelements(Q));
|
ggml_cuda_pool_alloc<float> q_data(ctx.pool(), ggml_nelements(Q));
|
||||||
ggml_cuda_pool_alloc<float> dst_data(ctx.pool(), ggml_nelements(dst));
|
ggml_cuda_pool_alloc<float> dst_data(ctx.pool(), ggml_nelements(dst));
|
||||||
|
ggml_cuda_pool_alloc<half> k_data(ctx.pool());
|
||||||
|
ggml_cuda_pool_alloc<half> v_data(ctx.pool());
|
||||||
|
|
||||||
repack_q(Q, q_data.get(), 8, 4, K->ne[2], ctx.stream());
|
repack_q(Q, q_data.get(), 8, 4, K->ne[2], ctx.stream());
|
||||||
|
|
||||||
@@ -102,6 +102,12 @@ static void glm45_flash_attention(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||||||
to_fp_16(K->data, k_data.get(), 1, nelem, ctx.stream());
|
to_fp_16(K->data, k_data.get(), 1, nelem, ctx.stream());
|
||||||
local_K.type = GGML_TYPE_F16;
|
local_K.type = GGML_TYPE_F16;
|
||||||
local_K.data = k_data.get();
|
local_K.data = k_data.get();
|
||||||
|
auto ts = ggml_type_size(K->type);
|
||||||
|
auto bs = ggml_blck_size(K->type);
|
||||||
|
local_K.nb[0] = sizeof(half);
|
||||||
|
local_K.nb[1] = sizeof(half)*bs * local_K.nb[1]/ts;
|
||||||
|
local_K.nb[2] = sizeof(half)*bs * local_K.nb[2]/ts;
|
||||||
|
local_K.nb[3] = sizeof(half)*bs * local_K.nb[3]/ts;
|
||||||
}
|
}
|
||||||
if (V->type != GGML_TYPE_F16) {
|
if (V->type != GGML_TYPE_F16) {
|
||||||
auto nelem = ggml_nelements(V);
|
auto nelem = ggml_nelements(V);
|
||||||
@@ -110,6 +116,12 @@ static void glm45_flash_attention(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||||||
to_fp_16(V->data, v_data.get(), 1, nelem, ctx.stream());
|
to_fp_16(V->data, v_data.get(), 1, nelem, ctx.stream());
|
||||||
local_V.type = GGML_TYPE_F16;
|
local_V.type = GGML_TYPE_F16;
|
||||||
local_V.data = v_data.get();
|
local_V.data = v_data.get();
|
||||||
|
auto ts = ggml_type_size(V->type);
|
||||||
|
auto bs = ggml_blck_size(V->type);
|
||||||
|
local_V.nb[0] = sizeof(half);
|
||||||
|
local_V.nb[1] = sizeof(half)*bs * local_V.nb[1]/ts;
|
||||||
|
local_V.nb[2] = sizeof(half)*bs * local_V.nb[2]/ts;
|
||||||
|
local_V.nb[3] = sizeof(half)*bs * local_V.nb[3]/ts;
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr int n_op_params = GGML_MAX_OP_PARAMS / sizeof(int);
|
constexpr int n_op_params = GGML_MAX_OP_PARAMS / sizeof(int);
|
||||||
|
|||||||
Reference in New Issue
Block a user