FlashMLA-2 (CPU): faster and smaller compute buffer size (#253)

* FlashMLA-2: eliminate intermediate f32 tensors

This works on the CPU. PP (prompt processing) performance is ~13% better
for 16k tokens, and the compute buffer is quite a bit smaller.

* FlashMLA-2: enable fast path only on the CPU for now

I did implement the necessary ops on CUDA, but something is
still wrong there, so for now we only use it when running
CPU-only.

* FlashMLA-2: slightly smaller compute buffer size
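
In other words: instead of first materializing src1 as a whole f32 tensor and
then quantizing that tensor to the kernel's vec_dot_type, each thread now
streams one row at a time through a small f32 scratch buffer. A minimal
standalone sketch of the idea (names and signatures here are illustrative,
not the actual ik_llama.cpp API):

#include <cstdint>
#include <vector>

// Hypothetical stand-ins for ggml's per-type conversion callbacks
// (ggml's type traits provide the real equivalents: to_float / from_float).
typedef void (*to_float_t)  (const void * x, float * y, int64_t k);
typedef void (*from_float_t)(const float * x, void * y, int64_t k);

// Convert nrows rows of ne0 elements from one type to another without ever
// allocating a whole f32 tensor: only one f32 row is live at a time.
static void convert_rows(const char * src, size_t src_row_size,
                         char * dst, size_t dst_row_size,
                         int64_t ne0, int64_t nrows,
                         to_float_t to_float, from_float_t from_float) {
    std::vector<float> scratch(ne0);   // the per-thread f32 scratch row
    for (int64_t row = 0; row < nrows; ++row) {
        to_float(src + row*src_row_size, scratch.data(), ne0);    // src type -> f32
        from_float(scratch.data(), dst + row*dst_row_size, ne0);  // f32 -> vec_dot_type
    }
}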

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Kawrakow authored on 2025-03-13 12:07:43 +02:00, committed by GitHub
parent 3f23ed68f1, commit 305fabfc3b
5 changed files with 225 additions and 48 deletions


@@ -843,7 +843,8 @@ GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const
                 op->type != GGML_TYPE_IQ1_S &&
                 op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
-            return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
+            return true;
+            //return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
         default:
             return true;
     }


@@ -12589,6 +12589,43 @@ static void ggml_compute_forward_repeat_f16(
     }
 }
 
+static void ggml_compute_forward_repeat_any(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src = dst->src[0];
+    GGML_ASSERT(ggml_can_repeat(src, dst));
+    GGML_ASSERT(src->type == dst->type);
+    GGML_ASSERT(src->nb[0] == ggml_type_size(src->type));
+
+    int64_t src_row_size = ggml_row_size(src->type, src->ne[0]);
+    GGML_ASSERT((int64_t)dst->nb[1] == src_row_size*dst->ne[0]/src->ne[0]);
+
+    int ith = params->ith;
+    int nth = params->nth;
+    int64_t nrows = ggml_nrows(dst);
+    int64_t nrows_per_thread = (nrows + nth - 1)/nth;
+    int64_t first_row = ith*nrows_per_thread;
+    if (first_row >= nrows) return;
+    int64_t last_row = MIN(first_row + nrows_per_thread, nrows);
+
+    for (int64_t row = first_row; row < last_row; ++row) {
+        int64_t i3 = row/(dst->ne[1]*dst->ne[2]);
+        int64_t i2 = (row - i3*dst->ne[1]*dst->ne[2])/dst->ne[1];
+        int64_t i1 = row - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1];
+        char * y = (char *)dst->data + i1*dst->nb[1] + i2*dst->nb[2] + i3*dst->nb[3];
+        int64_t i03 = i3 % src->ne[3];
+        int64_t i02 = i2 % src->ne[2];
+        int64_t i01 = i1 % src->ne[1];
+        const char * x = (const char *)src->data + i01*src->nb[1] + i02*src->nb[2] + i03*src->nb[3];
+        for (int64_t ir = 0; ir < dst->ne[0]/src->ne[0]; ++ir) {
+            memcpy(y, x, src_row_size);
+            y += src_row_size;
+        }
+    }
+}
+
 static void ggml_compute_forward_repeat(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -12609,7 +12646,8 @@ static void ggml_compute_forward_repeat(
             } break;
         default:
             {
-                GGML_ABORT("fatal error");
+                ggml_compute_forward_repeat_any(params, dst);
+                //GGML_ABORT("fatal error");
             }
     }
 }
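
Both ggml_compute_forward_repeat_any above and ggml_compute_forward_concat_any
below use the same threading scheme: flatten the (ne1, ne2, ne3) row grid into
one index, give every thread a contiguous chunk of rows, and map each flat
index back to coordinates. A self-contained sketch of just that arithmetic:

#include <cstdint>
#include <algorithm>

// Slice [0, nrows) into nth near-equal chunks and decompose each flat row
// index into (i1, i2, i3); this mirrors the loop structure of the kernels.
static void process_rows(int64_t ne1, int64_t ne2, int64_t ne3, int ith, int nth) {
    int64_t nrows = ne1*ne2*ne3;
    int64_t nrows_per_thread = (nrows + nth - 1)/nth;  // ceil(nrows/nth)
    int64_t first_row = ith*nrows_per_thread;
    if (first_row >= nrows) return;                    // surplus threads exit early
    int64_t last_row = std::min(first_row + nrows_per_thread, nrows);
    for (int64_t row = first_row; row < last_row; ++row) {
        int64_t i3 = row/(ne1*ne2);                    // slowest-varying index
        int64_t i2 = (row - i3*ne1*ne2)/ne1;
        int64_t i1 = row - i3*ne1*ne2 - i2*ne1;        // fastest-varying row index
        // ... process row (i1, i2, i3) ...
        (void)i1; (void)i2; (void)i3;
    }
}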
@@ -12762,6 +12800,44 @@ static void ggml_compute_forward_concat_f32(
     }
 }
 
+static void ggml_compute_forward_concat_any(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    GGML_ASSERT(src0->type == src1->type && src0->type == dst->type);
+
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+    // Let's do it for dim = 0 only for now
+    GGML_ASSERT(dim == 0);
+
+    int ith = params->ith;
+    int nth = params->nth;
+    int64_t nrows = ggml_nrows(dst);
+    int64_t nrows_per_thread = (nrows + nth - 1)/nth;
+    int64_t first_row = ith*nrows_per_thread;
+    if (first_row >= nrows) return;
+    int64_t last_row = MIN(first_row + nrows_per_thread, nrows);
+
+    int64_t src0_row_size = ggml_row_size(src0->type, src0->ne[0]);
+    int64_t src1_row_size = ggml_row_size(src1->type, src1->ne[0]);
+
+    for (int64_t row = first_row; row < last_row; ++row) {
+        int64_t i3 = row/(dst->ne[1]*dst->ne[2]);
+        int64_t i2 = (row - i3*dst->ne[1]*dst->ne[2])/dst->ne[1];
+        int64_t i1 = row - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1];
+        char * y = (char *)dst->data + i1*dst->nb[1] + i2*dst->nb[2] + i3*dst->nb[3];
+        const char * x0 = (const char *)src0->data + i1*src0->nb[1] + i2*src0->nb[2] + i3*src0->nb[3];
+        const char * x1 = (const char *)src1->data + i1*src1->nb[1] + i2*src1->nb[2] + i3*src1->nb[3];
+        memcpy(y, x0, src0_row_size);
+        memcpy(y + src0_row_size, x1, src1_row_size);
+    }
+}
+
 static void ggml_compute_forward_concat(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst) {
@@ -12776,7 +12852,8 @@ static void ggml_compute_forward_concat(
             } break;
         default:
             {
-                GGML_ABORT("fatal error");
+                ggml_compute_forward_concat_any(params, dst);
+                //GGML_ABORT("fatal error");
             }
     }
 }
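
Because the concatenation is along dim 0 and both sources share one type, a
destination row is just src0's row followed by src1's row, so two memcpys per
row work for any type, quantized ones included, as long as rows are
contiguous. A toy illustration (hypothetical helper; row sizes are what
ggml_row_size would report):

#include <cstdint>
#include <cstring>

// dim-0 concat of a single row: dst is x0's bytes then x1's bytes.
// Whole rows are copied, so no per-element conversion is ever needed.
static void concat_row_dim0(const uint8_t * x0, size_t row_size0,
                            const uint8_t * x1, size_t row_size1,
                            uint8_t * y) {
    memcpy(y, x0, row_size0);
    memcpy(y + row_size0, x1, row_size1);
}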
@@ -14302,7 +14379,17 @@ UseGgmlGemm1:;
         const size_t nbw3 = nbw2*ne12;
 
         assert(params->wsize >= ne13*nbw3);
-        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+        if (src1->type != GGML_TYPE_F32) {
+#if GGML_USE_IQK_MULMAT
+            char * work_buffer = wdata + ne13*nbw3 + ith*ne10*sizeof(float);
+            GGML_ASSERT(params->wsize >= ne13*nbw3 + nth*ne10*sizeof(float));
+            iqk_quantize_any(src1->type, vec_dot_type, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                             src1->data, wdata, work_buffer, type_traits[src1->type].to_float, from_float, ith, nth);
+#else
+            GGML_ABORT("fatal error");
+#endif
+        }
+        else {
 
 //#ifdef GGML_USE_IQK_MULMAT
 //        int ts = type_traits[vec_dot_type].type_size;
@@ -14348,6 +14435,7 @@ UseGgmlGemm1:;
             }
         }
 //#endif
+        }
 
         ggml_barrier(params->shared);
@@ -16250,28 +16338,28 @@ static void ggml_compute_forward_soft_max_f32(
             }
         }
 
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(wp[i]));
-        }
-#endif
+//#ifndef NDEBUG
+//        for (int i = 0; i < nc; ++i) {
+//            //printf("p[%d] = %f\n", i, p[i]);
+//            assert(!isnan(wp[i]));
+//        }
+//#endif
 
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, wp);
 
         ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
-        assert(sum > 0.0);
+        //assert(sum > 0.0);
 
         sum = 1.0/sum;
         ggml_vec_scale_f32(nc, dp, sum);
 
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(dp[i]));
-            assert(!isinf(dp[i]));
-        }
-#endif
+//#ifndef NDEBUG
+//        for (int i = 0; i < nc; ++i) {
+//            assert(!isnan(dp[i]));
+//            assert(!isinf(dp[i]));
+//        }
+//#endif
     }
 }
@@ -21498,6 +21586,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     if (node->src[1]->type != vec_dot_type) {
                         cur = ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]);
+                        if (node->src[1]->type != GGML_TYPE_F32) {
+                            cur += n_tasks*node->src[1]->ne[0]*sizeof(float); // src1->type -> f32 -> vec_dot_type
+                        }
                     }
                 } break;
             case GGML_OP_MUL_MAT_ID:
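
Taken together, the reservation for a mul_mat whose src1 must be converted is
the converted copy of src1 plus one f32 scratch row per thread, which is far
smaller than keeping a full f32 intermediate around. A back-of-the-envelope
sketch (the shape, thread count, and q8_K row size are assumptions for
illustration only):

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical case: src1 has ne0 = 4096 and 16384 rows, vec_dot_type is
    // q8_K with a row size of 4672 bytes (4096/256 = 16 blocks of 292 bytes),
    // and the graph runs on 16 threads.
    int64_t ne0 = 4096, nrows = 16384;
    size_t  row_size_vdt = 4672;
    int     n_tasks = 16;

    size_t cur = row_size_vdt*nrows;     // converted src1: ~73 MiB
    cur += n_tasks*ne0*sizeof(float);    // per-thread f32 scratch rows: 256 KiB
    // A full f32 copy of src1 would instead need ne0*nrows*4 = 256 MiB.
    printf("work buffer: %zu bytes\n", cur);
    return 0;
}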


@@ -185,6 +185,34 @@ void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, i
     }
 }
 
+void iqk_quantize_any(int from_type, int to_type,
+                      int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
+                      uint64_t nb0, uint64_t nb1, uint64_t nb2, uint64_t nb3,
+                      const void * x, void * y, void * work_buffer,
+                      to_float_t to_float, from_float_t from_float, int ith, int nth) {
+    auto type_x = ggml_type(from_type);
+    GGML_ASSERT(ggml_type_size(type_x) == nb0);
+    auto type_y = ggml_type(to_type);
+    auto row_size_y = ggml_row_size(type_y, ne0);
+    int64_t nrows = ne1*ne2*ne3;
+    int64_t nrows_per_thread = (nrows + nth - 1)/nth;
+    int64_t first_row = nrows_per_thread*ith;
+    if (first_row >= nrows) return;
+    int64_t last_row = std::min(first_row + nrows_per_thread, nrows);
+    for (int64_t row = first_row; row < last_row; ++row) {
+        int64_t i3 = row/(ne1*ne2);
+        int64_t i2 = (row - i3*ne1*ne2)/ne1;
+        int64_t i1 = row - i3*ne1*ne2 - i2*ne1;
+        const char * cx = (const char *)x + i1*nb1 + i2*nb2 + i3*nb3;
+        // TODO: special case common types such as f16, q8_0
+        //       (although the performance gains may be too small to justify the added complexity)
+        to_float((const void *)cx, (float *)work_buffer, ne0);
+        auto cy = (char *)y + (i3*ne1*ne2 + i2*ne1 + i1)*row_size_y;
+        from_float((const float *)work_buffer, (void *)cy, ne0);
+    }
+}
+
 size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
     IQ1BNQuantizer iq1bn;
     auto row_size = ggml_row_size(GGML_TYPE_IQ1_BN, n_per_row);
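
For a contiguous tensor the strides collapse to the obvious products, and a
call looks roughly like this sketch (single-threaded f16 -> q8_0, with the
conversion callbacks pulled from ggml's type traits just as the mul_mat path
above does; include paths are illustrative):

#include <cstdint>
#include <vector>
#include "ggml.h"
#include "iqk_quantize.h"

// Sketch: convert a contiguous ne0 x ne1 f16 matrix to q8_0 on one thread.
void f16_to_q8_0(const ggml_fp16_t * src, void * dst, int64_t ne0, int64_t ne1) {
    auto traits_src = ggml_internal_get_type_traits(GGML_TYPE_F16);
    auto traits_dst = ggml_internal_get_type_traits(GGML_TYPE_Q8_0);
    std::vector<float> work(ne0);                  // one f32 scratch row
    iqk_quantize_any(GGML_TYPE_F16, GGML_TYPE_Q8_0,
                     ne0, ne1, 1, 1,               // ne0..ne3
                     sizeof(ggml_fp16_t),          // nb0
                     ne0*sizeof(ggml_fp16_t),      // nb1: contiguous rows
                     ne0*ne1*sizeof(ggml_fp16_t),  // nb2
                     ne0*ne1*sizeof(ggml_fp16_t),  // nb3 (ne2 = ne3 = 1)
                     src, dst, work.data(),
                     traits_src.to_float, traits_dst.from_float,
                     /*ith=*/0, /*nth=*/1);
}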


@@ -248,6 +248,14 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor);
 // So we can re-pack Microsoft's BitNet I2_S quants
 void dequantize_row_ms_i2s(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 
+typedef void (*to_float_t)  (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+typedef void (*from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void iqk_quantize_any(int from_type, int to_type,
+                      int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
+                      uint64_t nb0, uint64_t nb1, uint64_t nb2, uint64_t nb3,
+                      const void * GGML_RESTRICT x, void * GGML_RESTRICT y, void * work_buffer,
+                      to_float_t to_float, from_float_t from_float, int ith, int nth);
+
 #ifdef __cplusplus
 }
 #endif