Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-10 08:20:09 +00:00)
FlashMLA-2: reduce compute buffer size (CUDA and CPU) (#260)
* FlashMLA-2: eliminate intermediate f32 tensors

  This works on the CPU. PP performance is ~13% better for 16k tokens and the compute buffer is quite a bit smaller.

* FlashMLA-2: enable fast path only on the CPU for now

  I did implement the necessary ops on CUDA, but something is still wrong there, so for now we only use it when running CPU-only.

* FlashMLA-2: slightly smaller compute buffer size

* Prepare wk_b when loading DeepSeek models (if wk_b is missing)

* Add some comments

* Fix case where wkv_b is quantized with k- or i-quants

* Fix CUDA

  There is an issue with quantized GEMV on CUDA when the left operand (the matrix) is not contiguous. So, for now, we also create wv_b during model loading and use that instead of the 3D view of wkv_b.

* FlashMLA-2: avoid conversions to f32 also on CUDA

* Be able to compute for more than 65535 tokens

  On CUDA this is just a quick hack that allows us to concatenate tensors with more than 65535 rows along the zeroth dimension, as needed by FlashMLA-2. Also needed some care in the perplexity tool to avoid int overflows when evaluating the computed logits.

* Reduce memory usage for FlashMLA-2

  Oh, also fix int overflow in the CUDA concat implementation. It is funny how the llama.cpp 64-bit police has gone (almost) everywhere and replaced 32-bit ints with 64-bit ints, needed or not, but hasn't done it where it is actually needed.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
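For a sense of the sizes involved: the intermediate kv_f32 tensor that FlashMLA-2 builds per layer is n_head*(n_embd_head_qk_nope + n_embd_head_v) rows by n_kv cached tokens of f32, and processing only a subset of the heads per pass shrinks it proportionally. A minimal sketch of that arithmetic, with DeepSeek-like dimensions (128 heads, 128 + 128 values per head, 16k cached tokens) assumed purely for illustration:

```cpp
// Rough size of the per-layer kv_f32 intermediate in FlashMLA-2 and how processing
// fewer heads per pass shrinks it. All dimensions below are assumptions chosen to
// look like DeepSeek; they are not read from any model file.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_head              = 128;    // assumed head count
    const int64_t n_embd_head_qk_nope = 128;    // assumed K(nope) head size
    const int64_t n_embd_head_v       = 128;    // assumed V head size
    const int64_t n_kv                = 16384;  // 16k cached tokens

    const int64_t rows_per_head = n_embd_head_qk_nope + n_embd_head_v;
    for (int64_t heads_per_pass = n_head; heads_per_pass >= 16; heads_per_pass /= 2) {
        const int64_t bytes = heads_per_pass * rows_per_head * n_kv * (int64_t) sizeof(float);
        std::printf("%3lld heads per pass -> kv_f32 buffer ~ %5lld MiB\n",
                    (long long) heads_per_pass, (long long) (bytes / (1024*1024)));
    }
    return 0;
}
```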
@@ -166,7 +166,7 @@ static void process_logits(
             break;
         }
         lock.unlock();
-        const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+        const results_log_softmax results = log_softmax(n_vocab, logits + int64_t(i)*n_vocab, tokens[i+1]);
         const double v = -results.log_softmax;
         local_nll += v;
         local_nll2 += v*v;
@@ -200,7 +200,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
             break;
         }
         lock.unlock();
-        const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
+        const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + int64_t(i)*nv, tokens[i+1]);
         local_nll += v;
         local_nll2 += v*v;
     }
@@ -618,7 +618,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
         if (num_batches > 1 && n_outputs > 0) {
             const auto * batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
+            logits.insert(logits.end(), batch_logits, batch_logits + int64_t(n_outputs) * n_vocab);
         }
     }
 
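The three hunks above are the int-overflow part of the change: with a large vocabulary, the flat logit offset i*n_vocab exceeds INT32_MAX well within a long perplexity run, even though both factors are ordinary ints. A small sketch of the arithmetic, with an assumed DeepSeek-sized vocabulary:

```cpp
// Why the int64_t casts in the hunks above matter: the flat logit offset i*n_vocab
// stops fitting in 32 bits. The vocabulary size is an assumed, DeepSeek-sized value
// used only to show the arithmetic.
#include <cstdint>
#include <cstdio>

int main() {
    const int32_t n_vocab = 129280;  // assumed vocabulary size
    const int32_t i       = 20000;   // token index into the evaluated logits

    const int64_t offset = int64_t(i) * n_vocab;   // what the fixed code computes
    std::printf("offset as int64_t : %lld\n", (long long) offset);
    std::printf("INT32_MAX         : %d\n", INT32_MAX);
    // With plain 32-bit ints the multiplication would overflow (undefined behaviour),
    // which is why one operand is widened before the pointer arithmetic.
    return 0;
}
```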
@@ -3354,7 +3354,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
     if (op->op == GGML_OP_MOE_FUSED_UP_GATE && a->type != op->src[1]->type) {
         return false;
     }
-    if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
+    if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16 && !ggml_is_quantized(a->type)) {
         return false;
     }
     if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
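A minimal sketch (not the ggml source) of what the relaxed supports_op check above permits: an f16 right-hand operand is no longer rejected when the left operand (the weight matrix) is quantized, only when it is some other non-f16 float type:

```cpp
// Same shape as the condition in ggml_backend_cuda_supports_op; return true = reject.
#include <cstdio>

enum class Type { F32, F16, Q4_K };

static bool is_quantized(Type t) { return t == Type::Q4_K; }

static bool rejects(Type a /* weights */, Type b /* activations */) {
    return b == Type::F16 && a != Type::F16 && !is_quantized(a);
}

int main() {
    std::printf("Q4_K weights + F16 activations: %s\n", rejects(Type::Q4_K, Type::F16) ? "rejected" : "allowed");
    std::printf("F32  weights + F16 activations: %s\n", rejects(Type::F32,  Type::F16) ? "rejected" : "allowed");
    std::printf("F16  weights + F16 activations: %s\n", rejects(Type::F16,  Type::F16) ? "rejected" : "allowed");
    return 0;
}
```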
@@ -248,17 +248,35 @@ static void ggml_cuda_op_bin_bcast(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
 
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    //GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
-    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
-    } else {
-        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
-            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+    if (src1->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+            op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+            op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
+        } else {
+            fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+                ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+            GGML_ABORT("fatal error");
+        }
+    }
+    else if (src1->type == GGML_TYPE_F16) {
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            op()(src0, src1, dst, (const float *)src0_dd, (const half *)src1_dd, (float *)dst_dd, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+            op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+            op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (float *)dst_dd, stream);
+        } else {
+            fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+                ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+            GGML_ABORT("fatal error");
+        }
+    }
+    else {
+        GGML_ABORT("fatal error");
+    }
 }
 
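A reduced sketch of the dispatch shape introduced above, with the CUDA kernel launch through op() replaced by a printout; the only point is that src1 == f16 now gets its own branch mirroring the f32 one:

```cpp
#include <cstdio>

enum class Type { F32, F16, Other };

static const char * name(Type t) { return t == Type::F32 ? "f32" : t == Type::F16 ? "f16" : "other"; }

static void bin_bcast_dispatch(Type src0, Type src1, Type dst) {
    if (src1 != Type::F32 && src1 != Type::F16) {
        std::printf("unsupported src1 type\n");
        return;
    }
    const bool ok = (src0 == Type::F32 && dst == Type::F32) ||
                    (src0 == Type::F16 && dst == Type::F16) ||
                    (src0 == Type::F16 && dst == Type::F32);
    if (ok) {
        std::printf("launch kernel: src0=%s src1=%s dst=%s\n", name(src0), name(src1), name(dst));
    } else {
        std::printf("unsupported combination: src0=%s src1=%s dst=%s\n", name(src0), name(src1), name(dst));
    }
}

int main() {
    bin_bcast_dispatch(Type::F16, Type::F16, Type::F16); // possible now that f16 src1 is handled
    bin_bcast_dispatch(Type::F32, Type::F32, Type::F32); // previously supported path
    return 0;
}
```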
@@ -1,7 +1,7 @@
 #include "concat.cuh"
 
 // contiguous kernels
-static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
+static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00) {
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
@@ -27,7 +27,35 @@ static __global__ void concat_f32_dim0(const float * x, const float * y, float *
     }
 }
 
-static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
+// contiguous kernels
+static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00,
+        int64_t nb02, int64_t nb12, int64_t nb2) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * nb2;
+
+    if (nidx < ne00) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * nb02;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            (nidx - ne00) +
+            blockIdx.y * (ne0 - ne00) +
+            blockIdx.z * nb12;
+        dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne01) {
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
@@ -53,7 +81,7 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float *
     }
 }
 
-static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
+static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne02) {
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
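The new concat_f32_dim0 overload addresses whole slabs through explicit element strides (nb02, nb12, nb2), with blockIdx.y selecting the row and blockIdx.z the slab. A host-side sketch that reproduces the same offset computation on made-up sizes:

```cpp
// Host-side replica of the index math in the new concat_f32_dim0 overload: the
// destination element comes from x while the column is below ne00 and from y
// afterwards. Sizes are made up for illustration.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne00 = 4, ne0 = 6;   // columns in src0 and in dst (src1 supplies ne0 - ne00)
    const int64_t ne01 = 3;            // rows per z-slab
    const int64_t nb02 = ne00*ne01, nb12 = (ne0 - ne00)*ne01, nb2 = ne0*ne01; // slab strides, in elements

    for (int64_t z = 0; z < 2; ++z) {
        for (int64_t row = 0; row < ne01; ++row) {        // row plays the role of blockIdx.y
            for (int64_t nidx = 0; nidx < ne0; ++nidx) {  // nidx plays the role of the thread index
                const int64_t offset_dst = nidx + row*ne0 + z*nb2;
                if (nidx < ne00) {
                    const int64_t offset_src = nidx + row*ne00 + z*nb02;
                    std::printf("dst[%2lld] <- x[%2lld]\n", (long long) offset_dst, (long long) offset_src);
                } else {
                    const int64_t offset_src = (nidx - ne00) + row*(ne0 - ne00) + z*nb12;
                    std::printf("dst[%2lld] <- y[%2lld]\n", (long long) offset_dst, (long long) offset_src);
                }
            }
        }
    }
    return 0;
}
```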
@@ -81,9 +109,23 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float *
 
 static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
     int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
+    if (dim == 0 && ne1 >= 65536) {
+        int64_t nstep = (ne1 + 32767)/32768;
+        for (int64_t istep = 0; istep < nstep; ++istep) {
+            int64_t i1 = 32768*istep;
+            int64_t n1 = i1 + 32768 <= ne1 ? 32768 : ne1 - i1;
+            dim3 gridDim(num_blocks, n1, ne2);
+            const float * xi = x + i1*ne00;
+            const float * yi = y + i1*(ne0 - ne00);
+            float * dst_i = dst + i1*ne0;
+            concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(xi, yi, dst_i, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1);
+        }
+        return;
+    }
     dim3 gridDim(num_blocks, ne1, ne2);
     if (dim == 0) {
         concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
+        //concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1);
         return;
     }
     if (dim == 1) {
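CUDA grids are limited to 65535 blocks in the y and z dimensions, which is what the ne1 >= 65536 branch above works around: the dim-0 concatenation is launched in row chunks of 32768, advancing the x/y/dst pointers per chunk. A standalone sketch of the chunk arithmetic (column counts are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne00 = 512, ne0 = 576;  // elements per row in src0 and dst
    const int64_t ne1  = 100000;          // rows: too many for a single grid launch

    const int64_t nstep = (ne1 + 32767)/32768;
    for (int64_t istep = 0; istep < nstep; ++istep) {
        const int64_t i1 = 32768*istep;
        const int64_t n1 = i1 + 32768 <= ne1 ? 32768 : ne1 - i1;
        // per-chunk pointer offsets in elements, matching xi, yi and dst_i above
        std::printf("chunk %lld: rows [%6lld, %6lld), x += %lld, y += %lld, dst += %lld\n",
                    (long long) istep, (long long) i1, (long long) (i1 + n1),
                    (long long) (i1*ne00), (long long) (i1*(ne0 - ne00)), (long long) (i1*ne0));
    }
    return 0;
}
```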
@@ -150,52 +192,77 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
+    GGML_ASSERT(src0->type == src1->type && src0->type == dst->type);
+
     cudaStream_t stream = ctx.stream();
 
     const int32_t dim = ((int32_t *) dst->op_params)[0];
 
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
+        (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1))) {
+        const size_t size0 = ggml_nbytes(src0);
+        const size_t size1 = ggml_nbytes(src1);
+        CUDA_CHECK(cudaMemcpyAsync((char *)dst->data, src0->data, size0, cudaMemcpyDeviceToDevice, stream));
+        CUDA_CHECK(cudaMemcpyAsync((char *)dst->data + size0, src1->data, size1, cudaMemcpyDeviceToDevice, stream));
+        return;
+    }
+
+    if (dim == 0 && src0->nb[0] == ggml_type_size(src0->type) && src1->nb[0] == ggml_type_size(src1->type) &&
+        src0->nb[1] % sizeof(float) == 0 && src1->nb[1] % sizeof(float) == 0) {
+        if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+            //if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) {
+            //    fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name);
+            //    GGML_ABORT("fatal error");
+            //}
+            const float * src0_d = (const float *)src0->data;
+            const float * src1_d = (const float *)src1->data;
+            float * dst_d = (float *)dst->data;
+            for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+                concat_f32_cuda(
+                    src0_d + i3 * (src0->nb[3] / 4),
+                    src1_d + i3 * (src1->nb[3] / 4),
+                    dst_d  + i3 * ( dst->nb[3] / 4),
+                    src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2],
+                    dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream);
+            }
+        } else {
+            dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
+            concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
+                (const char *)src0->data,
+                (const char *)src1->data,
+                (      char *)dst->data,
+                src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], src0->ne[3],
+                sizeof(float), src0->nb[1], src0->nb[2], src0->nb[3],
+                src1->ne[0]*src1->nb[0]/sizeof(float), src1->ne[1], src1->ne[2], src1->ne[3],
+                sizeof(float), src1->nb[1], src1->nb[2], src1->nb[3],
+                dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dst->ne[3],
+                sizeof(float), dst->nb[1], dst->nb[2], dst->nb[3], dim);
+        }
+        return;
+    }
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         //if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) {
         //    fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name);
         //    GGML_ABORT("fatal error");
         //}
         const float * src0_d = (const float *)src0->data;
         const float * src1_d = (const float *)src1->data;
 
         float * dst_d = (float *)dst->data;
 
-        if (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1)) {
-            const size_t size0 = ggml_nbytes(src0);
-            const size_t size1 = ggml_nbytes(src1);
-            CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
-            CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
-        } else {
-            for (int i3 = 0; i3 < dst->ne[3]; i3++) {
-                concat_f32_cuda(
-                    src0_d + i3 * (src0->nb[3] / 4),
-                    src1_d + i3 * (src1->nb[3] / 4),
-                    dst_d  + i3 * ( dst->nb[3] / 4),
-                    src0->ne[0], src0->ne[1], src0->ne[2],
-                    dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
-            }
-        }
+        for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+            concat_f32_cuda(
+                src0_d + i3 * (src0->nb[3] / 4),
+                src1_d + i3 * (src1->nb[3] / 4),
+                dst_d  + i3 * ( dst->nb[3] / 4),
+                src0->ne[0], src0->ne[1], src0->ne[2],
+                dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
+        }
 
         //if (dim != 3) {
         //    for (int i3 = 0; i3 < dst->ne[3]; i3++) {
         //        concat_f32_cuda(
         //            src0_d + i3 * (src0->nb[3] / 4),
         //            src1_d + i3 * (src1->nb[3] / 4),
         //            dst_d  + i3 * ( dst->nb[3] / 4),
         //            src0->ne[0], src0->ne[1], src0->ne[2],
         //            dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
         //    }
         //} else {
         //    const size_t size0 = ggml_nbytes(src0);
         //    const size_t size1 = ggml_nbytes(src1);
 
         //    CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
         //    CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
         //}
     } else {
         dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
         concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
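A sketch of the condition behind the new memcpy fast path above: when both inputs are contiguous and the concatenation runs along the slowest dimension that is actually in use, dst is literally src0's bytes followed by src1's bytes, so two device-to-device copies replace the kernel launch:

```cpp
#include <cstdint>
#include <cstdio>

struct Shape { int64_t ne[4]; };

// Mirrors the shape of the dim test used together with the contiguity checks above.
static bool concat_is_two_memcpys(int dim, const Shape & dst) {
    return dim == 3
        || (dim == 2 && dst.ne[3] == 1)
        || (dim == 1 && dst.ne[2]*dst.ne[3] == 1);
}

int main() {
    const Shape a{{576, 100000, 1, 1}};  // concat along dim 1, higher dims unused
    const Shape b{{576, 100000, 8, 1}};  // concat along dim 1, but dim 2 is in use
    std::printf("case a: %s\n", concat_is_two_memcpys(1, a) ? "two cudaMemcpyAsync calls" : "kernel path");
    std::printf("case b: %s\n", concat_is_two_memcpys(1, b) ? "two cudaMemcpyAsync calls" : "kernel path");
    return 0;
}
```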
src/llama.cpp
@@ -13755,31 +13755,52 @@ struct llm_build_context {
 
             if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) {
 
+                auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
+
-                ggml_tensor * k;
-                ggml_tensor * v;
+                auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024);
+                int n_max_head = n_head;
+                if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) {
+                    while (n_max_head%2 == 0 && kv_f32_size > cparams.attn_max_batch) {
+                        n_max_head /= 2; kv_f32_size /= 2;
+                    }
+                }
+                GGML_ASSERT(n_head % n_max_head == 0);
 
-                // For now this only works in the CPU implementation, so we only use it if there is just the CPU backend.
-                // If the code was compiled with CUDA (and/or Metal, Vulkan, whatever) support, this branch will not
-                // be taken even if no layers were offloaded to the GPU.
-                if (lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu) {
+                auto n_per_head = model.layers[il].wkv_b->ne[1] / n_head;
 
-                    auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
-                    auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
-                            kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
+                auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
+                        kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
 
-                    auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
+                ggml_tensor repeater;
+                repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1;
+                auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
+                cb(k_rope, "k_rope", il);
+
+                auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
+                q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+                cb(q, "q_concat", il);
+
+                ggml_build_forward_expand(gf, q);
+
+                for (int iter = 0; iter < n_head/n_max_head; ++iter) {
+
+                    auto wkv_b = ggml_view_2d(ctx0, model.layers[il].wkv_b, model.layers[il].wkv_b->ne[0], n_per_head*n_max_head,
+                            model.layers[il].wkv_b->nb[1], model.layers[il].wkv_b->nb[1]*n_per_head*n_max_head*iter);
+
+                    auto kv_f32 = ggml_mul_mat(ctx0, wkv_b, kv_cache_nope);
                     cb(kv_f32, "kv_f32", il);
 
-                    auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
-                            ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                    auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_max_head,
+                            ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
                             ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
                             ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
                     cb(v_f32, "v_f32", il);
 
-                    v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
+                    auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
                     cb(v, "v", il);
 
-                    auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
-                            ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                    auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_max_head,
+                            ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
                             ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
                     cb(k_nope_f32, "k_nope_f32", il);
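A standalone sketch of the head-chunking logic added above: the per-pass head count n_max_head is halved until the intermediate kv_f32 buffer fits the attn_max_batch budget (in MiB), and the graph then runs n_head/n_max_head passes. All concrete numbers are illustrative assumptions:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const int     n_head         = 128;                       // assumed head count
    const int64_t wkv_b_rows     = n_head * (128 + 128);      // assumed: heads * (qk_nope + v) rows in wkv_b
    const int64_t n_kv           = 16384;                     // cached tokens
    const int64_t attn_max_batch = 512;                       // assumed budget, MiB

    int     n_max_head  = n_head;
    int64_t kv_f32_size = wkv_b_rows * n_kv * (int64_t) sizeof(float) / (1024*1024);

    if (attn_max_batch > 0 && kv_f32_size > attn_max_batch) {
        while (n_max_head % 2 == 0 && kv_f32_size > attn_max_batch) {
            n_max_head /= 2; kv_f32_size /= 2;
        }
    }
    std::printf("process %d of %d heads per pass (%d passes), kv_f32 ~ %lld MiB per pass\n",
                n_max_head, n_head, n_head/n_max_head, (long long) kv_f32_size);
    return 0;
}
```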
@@ -13789,74 +13810,27 @@ struct llm_build_context {
                     ggml_build_forward_expand(gf, k_nope);
                     ggml_build_forward_expand(gf, v);
 
-                    auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
-                            kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
-
-                    ggml_tensor repeater;
-                    repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1;
-                    auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
-                    cb(k_rope, "k_rope", il);
-
-                    k = ggml_concat(ctx0, k_nope, k_rope, 0);
+                    auto k = ggml_concat(ctx0, k_nope, k_rope, 0);
                     cb(k, "k", il);
 
                     ggml_build_forward_expand(gf, k);
 
+                    auto q_iter = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], n_max_head,
+                            q->nb[1], q->nb[2], q->nb[2]*n_max_head*iter);
+
+                    kqv = ggml_flash_attn_ext(ctx0, q_iter, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
+                    if (q->ne[1] <= 8) {
+                        ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
+                    }
+                    cb(kqv, "kqv", il);
+
+                    if (iter == 0) {
+                        cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens);
+                    } else {
+                        cur = ggml_concat(ctx0, cur, ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens), 0);
+                    }
+
+                }
-                }
-                else {
-                    // Hahaha, we need to convert the KV cache for this layer to f32 because the general purpose ML library ggml does not
-                    // provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix
-                    // multiplication, which *must* be f32.
-                    auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0);
-                    auto kv_cache_view_f32 = ggml_cast(ctx0, kv_cache_view, GGML_TYPE_F32);
-                    cb(kv_cache_view_f32, "kv_cache_view_f32", il);
-
-                    // The no- and rotational position encoding portions of the KV cache
-                    auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0);
-                    auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv,
-                            kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank));
-
-                    auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
-                    cb(kv_f32, "kv_f32", il);
-
-                    auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
-                            ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                            ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
-                    cb(k_nope_f32, "k_nope_f32", il);
-
-                    ggml_tensor repeater;
-                    repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1;
-                    auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3);
-                    cb(k_rope_f32, "k_rope_f32", il);
-
-                    auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0);
-                    cb(k_f32, "k_f32", il);
-
-                    k = ggml_cast(ctx0, k_f32, kv_self.kv_l[il]->type);
-                    cb(k, "k", il);
-
-                    auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
-                            ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                            ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
-                            ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
-                    cb(v_f32, "v_f32", il);
-
-                    v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
-                    cb(v, "v", il);
-                }
-
-                auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
-                q = ggml_permute(ctx0, q, 0, 2, 1, 3);
-                cb(q, "q_concat", il);
-
-                ggml_build_forward_expand(gf, q);
-
-                kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
-                if (q->ne[1] <= 8) {
-                    ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
-                }
-                cb(kqv, "kqv", il);
-
-                cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens);
-
             }
             else {
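A sketch of how the per-chunk attention outputs are reassembled by the iter loop above: the first pass starts cur, later passes are concatenated along the feature dimension, so the final width equals the single-pass n_embd_head_v*n_head. Dimensions are assumed for illustration:

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int n_embd_head_v = 128, n_head = 128, n_max_head = 32, n_tokens = 4;
    const int n_iter = n_head / n_max_head;

    std::vector<std::vector<float>> cur(n_tokens);       // one feature row per token
    for (int iter = 0; iter < n_iter; ++iter) {
        for (int t = 0; t < n_tokens; ++t) {
            // stand-in for ggml_reshape_2d(kqv, n_embd_head_v*n_max_head, n_tokens)
            std::vector<float> chunk(n_embd_head_v * n_max_head, float(iter));
            // iter == 0 starts cur; later iterations mimic ggml_concat along dim 0
            cur[t].insert(cur[t].end(), chunk.begin(), chunk.end());
        }
    }
    std::printf("features per token: %zu (expected %d)\n", cur[0].size(), n_embd_head_v * n_head);
    return 0;
}
```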