Reduce memory usage for FlashMLA-2

Oh, also fix int overflow in the CUDA concat implementation.

It is funny how the llama.cpp 64-bit police has gone (almost) everywhere
and replaced 32-bit ints with 64-bit ints, needed or not,
but hasn't done it where it is actually needed.
This commit is contained in:
Iwan Kawrakow
2025-03-17 15:00:26 +02:00
parent b9daa401d7
commit b147e31f5a
2 changed files with 91 additions and 37 deletions

View File

@@ -1,7 +1,7 @@
#include "concat.cuh"
// contiguous kernels
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
@@ -27,7 +27,35 @@ static __global__ void concat_f32_dim0(const float * x, const float * y, float *
}
}
static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
// contiguous kernels
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00,
int64_t nb02, int64_t nb12, int64_t nb2) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * nb2;
if (nidx < ne00) { // src0
int offset_src =
nidx +
blockIdx.y * ne00 +
blockIdx.z * nb02;
dst[offset_dst] = x[offset_src];
} else {
int offset_src =
(nidx - ne00) +
blockIdx.y * (ne0 - ne00) +
blockIdx.z * nb12;
dst[offset_dst] = y[offset_src];
}
}
static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne01) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
@@ -53,7 +81,7 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float *
}
}
static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne02) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
@@ -90,13 +118,14 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n
const float * xi = x + i1*ne00;
const float * yi = y + i1*(ne0 - ne00);
float * dst_i = dst + i1*ne0;
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(xi, yi, dst_i, ne0, ne00);
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(xi, yi, dst_i, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1);
}
return;
}
dim3 gridDim(num_blocks, ne1, ne2);
if (dim == 0) {
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
//concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1);
return;
}
if (dim == 1) {

View File

@@ -13757,55 +13757,80 @@ struct llm_build_context {
auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
cb(kv_f32, "kv_f32", il);
auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024);
int n_max_head = n_head;
if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) {
while (n_max_head%2 == 0 && kv_f32_size > cparams.attn_max_batch) {
n_max_head /= 2; kv_f32_size /= 2;
}
}
GGML_ASSERT(n_head % n_max_head == 0);
auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
cb(v_f32, "v_f32", il);
auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
cb(v, "v", il);
auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
cb(k_nope_f32, "k_nope_f32", il);
auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_self.kv_l[il]->type);
cb(k_nope, "k_nope", il);
ggml_build_forward_expand(gf, k_nope);
ggml_build_forward_expand(gf, v);
auto n_per_head = model.layers[il].wkv_b->ne[1] / n_head;
auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
ggml_tensor repeater;
repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1;
repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1;
auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
cb(k_rope, "k_rope", il);
auto k = ggml_concat(ctx0, k_nope, k_rope, 0);
cb(k, "k", il);
ggml_build_forward_expand(gf, k);
auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
cb(q, "q_concat", il);
ggml_build_forward_expand(gf, q);
kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
if (q->ne[1] <= 8) {
ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
}
cb(kqv, "kqv", il);
for (int iter = 0; iter < n_head/n_max_head; ++iter) {
cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens);
auto wkv_b = ggml_view_2d(ctx0, model.layers[il].wkv_b, model.layers[il].wkv_b->ne[0], n_per_head*n_max_head,
model.layers[il].wkv_b->nb[1], model.layers[il].wkv_b->nb[1]*n_per_head*n_max_head*iter);
auto kv_f32 = ggml_mul_mat(ctx0, wkv_b, kv_cache_nope);
cb(kv_f32, "kv_f32", il);
auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_max_head,
ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
cb(v_f32, "v_f32", il);
auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
cb(v, "v", il);
auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_max_head,
ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
cb(k_nope_f32, "k_nope_f32", il);
auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_self.kv_l[il]->type);
cb(k_nope, "k_nope", il);
ggml_build_forward_expand(gf, k_nope);
ggml_build_forward_expand(gf, v);
auto k = ggml_concat(ctx0, k_nope, k_rope, 0);
cb(k, "k", il);
ggml_build_forward_expand(gf, k);
auto q_iter = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], n_max_head,
q->nb[1], q->nb[2], q->nb[2]*n_max_head*iter);
kqv = ggml_flash_attn_ext(ctx0, q_iter, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
if (q->ne[1] <= 8) {
ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
}
cb(kqv, "kqv", il);
if (iter == 0) {
cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens);
} else {
cur = ggml_concat(ctx0, cur, ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens), 0);
}
}
}
else {