FlashMLA-2: reduce compute buffer size (CUDA and CPU) (#260)

* FlashMLA-2: eliminate intermediate f32 tensors

This works on the CPU. PP performance is ~13% better for 16k tokens
and compute buffer is quite a bit smaller.

* FlashMLA-2: enable fast path only on the CPU for now

I did implement the necessary ops on CUDA, but something is
still wrong there, so for now we only use it when running
CPU-only.

* FlashMLA-2: slightly smaller computer buffer size

* Prepare wk_b when loading DeepSeek models (if wk_b is missing)

* Add some comments

* Fix case where wkv_b is quantized with k- or i-quants.

* Fix CUDA

There is an issue with quantized GEMV on CUDA when the left operand
(the matrix) is not contiguous. So, for now, we also create wv_b
during model loading and use that instead of the 3D view of wkv_b.

* FlashMLA-2: avoid conversions to f32 also on CUDA

* Be able to compute for more than 65535 tokens

On CUDA just a quick hack that allows us to cancatenate tensors
with more than 65535 rows along zroth dimension as needed by
FlashMLA-2. Also needed some care in the perplexity tool to
avoid int overflows when evaluating the computed logits.

* Reduce memory usage for FlashMLA-2

Oh, also fix int overflow in the CUDA concat implementation.

It is funny how the llama.cpp 64-bit police has gone (almost) everywhere
and replaced 32-bit ints with 64-bit ints, needed or not,
but hasn't done it where it is actually needed.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-03-18 07:36:42 +01:00
committed by GitHub
parent 0f19a500a9
commit 9fe2b06f79
5 changed files with 184 additions and 125 deletions

View File

@@ -3354,7 +3354,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (op->op == GGML_OP_MOE_FUSED_UP_GATE && a->type != op->src[1]->type) {
return false;
}
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16 && !ggml_is_quantized(a->type)) {
return false;
}
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {

View File

@@ -248,17 +248,35 @@ static void ggml_cuda_op_bin_bcast(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
GGML_ASSERT(src1->type == GGML_TYPE_F32);
//GGML_ASSERT(src1->type == GGML_TYPE_F32);
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
} else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
if (src1->type == GGML_TYPE_F32) {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
} else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
GGML_ABORT("fatal error");
}
}
else if (src1->type == GGML_TYPE_F16) {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
op()(src0, src1, dst, (const float *)src0_dd, (const half *)src1_dd, (float *)dst_dd, stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (float *)dst_dd, stream);
} else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
GGML_ABORT("fatal error");
}
}
else {
GGML_ABORT("fatal error");
}
}

View File

@@ -1,7 +1,7 @@
#include "concat.cuh"
// contiguous kernels
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
@@ -27,7 +27,35 @@ static __global__ void concat_f32_dim0(const float * x, const float * y, float *
}
}
static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
// contiguous kernels
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00,
int64_t nb02, int64_t nb12, int64_t nb2) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * nb2;
if (nidx < ne00) { // src0
int offset_src =
nidx +
blockIdx.y * ne00 +
blockIdx.z * nb02;
dst[offset_dst] = x[offset_src];
} else {
int offset_src =
(nidx - ne00) +
blockIdx.y * (ne0 - ne00) +
blockIdx.z * nb12;
dst[offset_dst] = y[offset_src];
}
}
static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne01) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
@@ -53,7 +81,7 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float *
}
}
static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne02) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
@@ -81,9 +109,23 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float *
static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
if (dim == 0 && ne1 >= 65536) {
int64_t nstep = (ne1 + 32767)/32768;
for (int64_t istep = 0; istep < nstep; ++istep) {
int64_t i1 = 32768*istep;
int64_t n1 = i1 + 32768 <= ne1 ? 32768 : ne1 - i1;
dim3 gridDim(num_blocks, n1, ne2);
const float * xi = x + i1*ne00;
const float * yi = y + i1*(ne0 - ne00);
float * dst_i = dst + i1*ne0;
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(xi, yi, dst_i, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1);
}
return;
}
dim3 gridDim(num_blocks, ne1, ne2);
if (dim == 0) {
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
//concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1);
return;
}
if (dim == 1) {
@@ -150,52 +192,77 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == src1->type && src0->type == dst->type);
cudaStream_t stream = ctx.stream();
const int32_t dim = ((int32_t *) dst->op_params)[0];
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
(dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1))) {
const size_t size0 = ggml_nbytes(src0);
const size_t size1 = ggml_nbytes(src1);
CUDA_CHECK(cudaMemcpyAsync((char *)dst->data, src0->data, size0, cudaMemcpyDeviceToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync((char *)dst->data + size0, src1->data, size1, cudaMemcpyDeviceToDevice, stream));
return;
}
if (dim == 0 && src0->nb[0] == ggml_type_size(src0->type) && src1->nb[0] == ggml_type_size(src1->type) &&
src0->nb[1] % sizeof(float) == 0 && src1->nb[1] % sizeof(float) == 0) {
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
//if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) {
// fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name);
// GGML_ABORT("fatal error");
//}
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
concat_f32_cuda(
src0_d + i3 * (src0->nb[3] / 4),
src1_d + i3 * (src1->nb[3] / 4),
dst_d + i3 * ( dst->nb[3] / 4),
src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2],
dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream);
}
} else {
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
(const char *)src0->data,
(const char *)src1->data,
( char *)dst->data,
src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], src0->ne[3],
sizeof(float), src0->nb[1], src0->nb[2], src0->nb[3],
src1->ne[0]*src1->nb[0]/sizeof(float), src1->ne[1], src1->ne[2], src1->ne[3],
sizeof(float), src1->nb[1], src1->nb[2], src1->nb[3],
dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dst->ne[3],
sizeof(float), dst->nb[1], dst->nb[2], dst->nb[3], dim);
}
return;
}
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
//if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) {
// fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name);
// GGML_ABORT("fatal error");
//}
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
if (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1)) {
const size_t size0 = ggml_nbytes(src0);
const size_t size1 = ggml_nbytes(src1);
CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
} else {
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
concat_f32_cuda(
src0_d + i3 * (src0->nb[3] / 4),
src1_d + i3 * (src1->nb[3] / 4),
dst_d + i3 * ( dst->nb[3] / 4),
src0->ne[0], src0->ne[1], src0->ne[2],
dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
}
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
concat_f32_cuda(
src0_d + i3 * (src0->nb[3] / 4),
src1_d + i3 * (src1->nb[3] / 4),
dst_d + i3 * ( dst->nb[3] / 4),
src0->ne[0], src0->ne[1], src0->ne[2],
dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
}
//if (dim != 3) {
// for (int i3 = 0; i3 < dst->ne[3]; i3++) {
// concat_f32_cuda(
// src0_d + i3 * (src0->nb[3] / 4),
// src1_d + i3 * (src1->nb[3] / 4),
// dst_d + i3 * ( dst->nb[3] / 4),
// src0->ne[0], src0->ne[1], src0->ne[2],
// dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
// }
//} else {
// const size_t size0 = ggml_nbytes(src0);
// const size_t size1 = ggml_nbytes(src1);
// CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
// CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
//}
} else {
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(