Mirror of https://github.com/ikawrakow/ik_llama.cpp.git
Fix CUDA FlashMLA-3 with quantized KV cache (#400)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
@@ -1362,26 +1362,46 @@ void launch_fattn_new_mma(
         to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream);
         K_data = (char *) K_f16.ptr;

-        const size_t bs = ggml_blck_size(K->type);
-        const size_t ts = ggml_type_size(K->type);
+        nb11 = K->ne[0]*sizeof(half);
+        nb12 = nb11*K->ne[1];
+        nb13 = nb12*K->ne[2];

-        nb11 = nb11*bs*sizeof(half)/ts;
-        nb12 = nb12*bs*sizeof(half)/ts;
-        nb13 = nb13*bs*sizeof(half)/ts;
+        // Original PR in llama.cpp. I don't think that can work when K is not contiguous (e.g., nb11 > nb12), there are
+        // gaps between the rows, etc., as ggml_get_to_fp16_cuda stores into contiguous memory.
+        //const size_t bs = ggml_blck_size(K->type);
+        //const size_t ts = ggml_type_size(K->type);
+
+        //nb11 = nb11*bs*sizeof(half)/ts;
+        //nb12 = nb12*bs*sizeof(half)/ts;
+        //nb13 = nb13*bs*sizeof(half)/ts;
     }

     if (need_f16_V && V->type != GGML_TYPE_F16) {
-        V_f16.alloc(ggml_nelements(V));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
-        to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
-        V_data = (char *) V_f16.ptr;
+        if constexpr (DV == 512) {
+            // DeepSeek. In this case the V cache is the same as the K cache, except that
+            // it has 512 elements per row instead of 576.
+            nb21 = nb11;
+            nb22 = nb12;
+            nb23 = nb13;
+            V_data = K_data;
+        } else {
+            V_f16.alloc(ggml_nelements(V));
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+            to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
+            V_data = (char *) V_f16.ptr;

-        const size_t bs = ggml_blck_size(V->type);
-        const size_t ts = ggml_type_size(V->type);
+            nb21 = K->ne[0]*sizeof(half);
+            nb22 = nb21*V->ne[1];
+            nb23 = nb22*V->ne[2];

-        nb21 = nb21*bs*sizeof(half)/ts;
-        nb22 = nb22*bs*sizeof(half)/ts;
-        nb23 = nb23*bs*sizeof(half)/ts;
+            // Original PR in llama.cpp. Same comment as above for the K cache.
+            //const size_t bs = ggml_blck_size(V->type);
+            //const size_t ts = ggml_type_size(V->type);
+
+            //nb21 = nb21*bs*sizeof(half)/ts;
+            //nb22 = nb22*bs*sizeof(half)/ts;
+            //nb23 = nb23*bs*sizeof(half)/ts;
+        }
     }

     int parallel_blocks = 1;
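The key change in the K section is that strides are now recomputed from the tensor shape instead of scaled from the quantized strides. As the in-diff comment notes, ggml_get_to_fp16_cuda dequantizes into a freshly allocated, contiguous buffer, so the f16 strides follow from K->ne alone; scaling the original byte strides by bs*sizeof(half)/ts reproduces that layout only when the quantized source itself has no gaps between its rows. Below is a minimal host-side sketch of the arithmetic; it is not part of the commit, and the Q8_0 block constants and the row length are used purely for illustration.

    #include <cstdio>
    #include <cstddef>

    int main() {
        // Q8_0: blocks of 32 elements stored in 34 bytes (32 int8 + fp16 scale).
        const size_t bs = 32;           // ggml_blck_size(GGML_TYPE_Q8_0)
        const size_t ts = 34;           // ggml_type_size(GGML_TYPE_Q8_0)
        const size_t sizeof_half = 2;   // stand-in for sizeof(half) on the device
        const size_t ne0 = 576;         // MLA K row length (illustrative)

        // Byte stride of one contiguous quantized row: 18 blocks * 34 bytes.
        const size_t nb11_q = ne0/bs*ts;                             // 612

        // Old approach: scale the quantized stride to f16. Correct only when
        // the quantized rows are packed back to back.
        const size_t nb11_scaled = nb11_q*bs*sizeof_half/ts;         // 1152

        // New approach: derive the stride from the element count; the converter
        // writes rows contiguously, so this is correct by construction.
        const size_t nb11_packed = ne0*sizeof_half;                  // 1152

        printf("quantized: %zu, scaled: %zu, packed: %zu\n",
               nb11_q, nb11_scaled, nb11_packed);
        return 0;
    }

For a packed source the two formulas agree (612*32*2/34 = 1152 = 576*2). For a non-contiguous view, where the row stride exceeds ne0/bs*ts, the scaled value would carry the source's inter-row gaps into a buffer that has none, which is exactly the failure mode the comment describes.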
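The new DV == 512 branch exploits the DeepSeek MLA cache layout: each K row holds 576 f16 values, and per the in-diff comment the V cache is the same as the K cache except that its rows are 512 elements long, so once K has been dequantized there is nothing left to convert; V simply aliases K's buffer and strides. A rough stand-alone sketch of that aliasing follows; it is not the commit's code, and n_tokens plus the uint16_t stand-in for half are illustrative assumptions.

    #include <cstdio>
    #include <cstdint>
    #include <vector>

    int main() {
        const size_t n_tokens = 8;    // illustrative token count
        const size_t k_row    = 576;  // MLA K row: 512 (V part) + 64 (RoPE part)
        const size_t v_row    = 512;  // V row length

        // Packed dequantized K cache; uint16_t stands in for half (2 bytes each).
        std::vector<uint16_t> K_f16(n_tokens*k_row);

        // Strides recomputed from the packed layout, as in the patch:
        const size_t nb11 = k_row*sizeof(uint16_t);        // K row stride in bytes
        const size_t nb21 = nb11;                          // V reuses K's row stride
        const char * V_data = (const char *) K_f16.data(); // and K's storage

        // V element j of token i lives at V_data + i*nb21 + j*2, valid for j < v_row;
        // the trailing 64 RoPE values of each K row are simply never read as V.
        const uint16_t * v_row3 = (const uint16_t *)(V_data + 3*nb21);
        printf("token 3: V row aliases the first %zu of %zu K values\n", v_row, k_row);
        (void) v_row3;
        return 0;
    }

The payoff is twofold: the V dequantization pass is skipped entirely, and no separate V_f16 temporary is allocated in this branch, halving the scratch-memory footprint of the conversion step.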