mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 23:24:13 +00:00
Reduce memory usage for FlashMLA-2
Oh, also fix int overflow in the CUDA concat implementation. It is funny how the llama.cpp 64-bit police has gone (almost) everywhere and replaced 32-bit ints with 64-bit ints, needed or not, but hasn't done it where it is actually needed.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#include "concat.cuh"
|
||||
|
||||
// contiguous kernels
|
||||
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
|
||||
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00) {
|
||||
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
@@ -27,7 +27,35 @@ static __global__ void concat_f32_dim0(const float * x, const float * y, float *
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
|
||||
// contiguous kernels
|
||||
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00,
|
||||
int64_t nb02, int64_t nb12, int64_t nb2) {
|
||||
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int offset_dst =
|
||||
nidx +
|
||||
blockIdx.y * ne0 +
|
||||
blockIdx.z * nb2;
|
||||
|
||||
if (nidx < ne00) { // src0
|
||||
int offset_src =
|
||||
nidx +
|
||||
blockIdx.y * ne00 +
|
||||
blockIdx.z * nb02;
|
||||
dst[offset_dst] = x[offset_src];
|
||||
} else {
|
||||
int offset_src =
|
||||
(nidx - ne00) +
|
||||
blockIdx.y * (ne0 - ne00) +
|
||||
blockIdx.z * nb12;
|
||||
dst[offset_dst] = y[offset_src];
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne01) {
|
||||
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
@@ -53,7 +81,7 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float *
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
|
||||
static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne02) {
|
||||
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
@@ -90,13 +118,14 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n
|
||||
const float * xi = x + i1*ne00;
|
||||
const float * yi = y + i1*(ne0 - ne00);
|
||||
float * dst_i = dst + i1*ne0;
|
||||
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(xi, yi, dst_i, ne0, ne00);
|
||||
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(xi, yi, dst_i, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1);
|
||||
}
|
||||
return;
|
||||
}
|
||||
dim3 gridDim(num_blocks, ne1, ne2);
|
||||
if (dim == 0) {
|
||||
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
|
||||
//concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1);
|
||||
return;
|
||||
}
|
||||
if (dim == 1) {
|
||||
|
||||
@@ -13757,55 +13757,80 @@ struct llm_build_context {
|
||||
|
||||
auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
|
||||
|
||||
auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
|
||||
cb(kv_f32, "kv_f32", il);
|
||||
auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024);
|
||||
int n_max_head = n_head;
|
||||
if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) {
|
||||
while (n_max_head%2 == 0 && kv_f32_size > cparams.attn_max_batch) {
|
||||
n_max_head /= 2; kv_f32_size /= 2;
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(n_head % n_max_head == 0);
|
||||
|
||||
auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
|
||||
ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
|
||||
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
|
||||
ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
|
||||
cb(v_f32, "v_f32", il);
|
||||
|
||||
auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
|
||||
cb(v, "v", il);
|
||||
|
||||
auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
|
||||
ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
|
||||
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
|
||||
cb(k_nope_f32, "k_nope_f32", il);
|
||||
|
||||
auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_self.kv_l[il]->type);
|
||||
cb(k_nope, "k_nope", il);
|
||||
|
||||
ggml_build_forward_expand(gf, k_nope);
|
||||
ggml_build_forward_expand(gf, v);
|
||||
auto n_per_head = model.layers[il].wkv_b->ne[1] / n_head;
|
||||
|
||||
auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
|
||||
kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
|
||||
|
||||
ggml_tensor repeater;
|
||||
repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1;
|
||||
repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1;
|
||||
auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
|
||||
cb(k_rope, "k_rope", il);
|
||||
|
||||
auto k = ggml_concat(ctx0, k_nope, k_rope, 0);
|
||||
cb(k, "k", il);
|
||||
|
||||
ggml_build_forward_expand(gf, k);
|
||||
|
||||
auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
||||
cb(q, "q_concat", il);
|
||||
|
||||
ggml_build_forward_expand(gf, q);
|
||||
|
||||
kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
|
||||
if (q->ne[1] <= 8) {
|
||||
ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
|
||||
}
|
||||
cb(kqv, "kqv", il);
|
||||
for (int iter = 0; iter < n_head/n_max_head; ++iter) {
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens);
|
||||
auto wkv_b = ggml_view_2d(ctx0, model.layers[il].wkv_b, model.layers[il].wkv_b->ne[0], n_per_head*n_max_head,
|
||||
model.layers[il].wkv_b->nb[1], model.layers[il].wkv_b->nb[1]*n_per_head*n_max_head*iter);
|
||||
|
||||
auto kv_f32 = ggml_mul_mat(ctx0, wkv_b, kv_cache_nope);
|
||||
cb(kv_f32, "kv_f32", il);
|
||||
|
||||
auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_max_head,
|
||||
ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
|
||||
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
|
||||
ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
|
||||
cb(v_f32, "v_f32", il);
|
||||
|
||||
auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
|
||||
cb(v, "v", il);
|
||||
|
||||
auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_max_head,
|
||||
ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
|
||||
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
|
||||
cb(k_nope_f32, "k_nope_f32", il);
|
||||
|
||||
auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_self.kv_l[il]->type);
|
||||
cb(k_nope, "k_nope", il);
|
||||
|
||||
ggml_build_forward_expand(gf, k_nope);
|
||||
ggml_build_forward_expand(gf, v);
|
||||
|
||||
auto k = ggml_concat(ctx0, k_nope, k_rope, 0);
|
||||
cb(k, "k", il);
|
||||
|
||||
ggml_build_forward_expand(gf, k);
|
||||
|
||||
auto q_iter = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], n_max_head,
|
||||
q->nb[1], q->nb[2], q->nb[2]*n_max_head*iter);
|
||||
|
||||
kqv = ggml_flash_attn_ext(ctx0, q_iter, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
|
||||
if (q->ne[1] <= 8) {
|
||||
ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
|
||||
}
|
||||
cb(kqv, "kqv", il);
|
||||
|
||||
if (iter == 0) {
|
||||
cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens);
|
||||
} else {
|
||||
cur = ggml_concat(ctx0, cur, ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens), 0);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
else {
|
||||
|
||||
Reference in New Issue
Block a user