mirror of https://github.com/ikawrakow/ik_llama.cpp.git
Handle quantized cache
@@ -13,6 +13,7 @@
#include "fattn-mma-f16-interface.cuh"
#include "fattn-new-mma.cuh"
#include "fattn.cuh"
#include "convert.cuh"

#include <cstdint>

@@ -112,6 +113,31 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
    auto local_Q = *Q;
    auto local_dst = *dst;

    ggml_tensor local_K, local_V;
    ggml_cuda_pool_alloc<half> K_f16(ctx.pool());
    if (ggml_is_quantized(K->type)) {
        // We need to dequantize here, else we will dequantize the same cache twice
        K_f16.alloc(ggml_nelements(K));
        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
        to_fp16(K->data, K_f16.ptr, 1, ggml_nelements(K), ctx.stream());

        auto bs = ggml_blck_size(K->type);
        auto ts = ggml_type_size(K->type);

        local_K = *K;
        local_K.data = K_f16.get();
        local_K.type = GGML_TYPE_F16;
        local_K.nb[0] = sizeof(half);
        local_K.nb[1] = local_K.nb[1]*bs*sizeof(half)/ts;
        local_K.nb[2] = local_K.nb[2]*bs*sizeof(half)/ts;
        local_K.nb[3] = local_K.nb[3]*bs*sizeof(half)/ts;
        local_dst.src[1] = &local_K;

        local_V = local_K;
        local_V.ne[0] = V->ne[0];
        local_dst.src[2] = &local_V;
    }

    local_Q.ne[2] = 16;
    local_dst.ne[1] = 16;
    local_dst.src[0] = &local_Q;
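
The stride rescaling above converts the K cache's byte strides from the quantized block layout to the dequantized f16 layout: a block of bs values occupies ts bytes when quantized but bs*sizeof(half) bytes once converted, so nb[1..3] are scaled by bs*sizeof(half)/ts while nb[0] is reset directly to sizeof(half). Below is a minimal standalone sketch of that arithmetic, assuming a Q8_0-style block (32 values stored in 34 bytes: a 2-byte scale plus 32 int8 values); the helper name and the numbers are illustrative only and are not part of the commit.

// Sketch only: checks the nb*bs*sizeof(half)/ts rescaling used in the diff above.
// sizeof(uint16_t) stands in for sizeof(half) so this compiles as plain C++.
#include <cassert>
#include <cstddef>
#include <cstdint>

static size_t stride_to_f16(size_t nb, size_t bs, size_t ts) {
    // One block of bs values: ts bytes quantized -> bs*sizeof(uint16_t) bytes as f16.
    return nb*bs*sizeof(uint16_t)/ts;
}

int main() {
    const size_t bs  = 32;          // values per Q8_0-style block (assumed)
    const size_t ts  = 34;          // bytes per block: 2-byte scale + 32 int8 (assumed)
    const size_t ne0 = 128;         // head dimension, i.e. values per cache row (assumed)
    const size_t nb1 = ne0/bs*ts;   // quantized row stride in bytes: 136
    assert(stride_to_f16(nb1, bs, ts) == ne0*sizeof(uint16_t)); // 256 bytes per dense f16 row
    return 0;
}

The element counts ne[] stay in units of values and therefore need no rescaling; only the byte strides nb[1..3] carry the block-layout factor.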