FlashMLA-2 (CPU): faster and smaller compute buffer size (#253)

* FlashMLA-2: eliminate intermediate f32 tensors

This works on the CPU. PP performance is ~13% better for 16k tokens
and compute buffer is quite a bit smaller.

* FlashMLA-2: enable fast path only on the CPU for now

I did implement the necessary ops on CUDA, but something is
still wrong there, so for now we only use it when running
CPU-only.

* FlashMLA-2: slightly smaller computer buffer size

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-03-13 12:07:43 +02:00
committed by GitHub
parent 3f23ed68f1
commit 305fabfc3b
5 changed files with 225 additions and 48 deletions

View File

@@ -185,6 +185,34 @@ void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, i
}
void iqk_quantize_any(int from_type, int to_type,
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
uint64_t nb0, uint64_t nb1, uint64_t nb2, uint64_t nb3,
const void * x, void * y, void * work_buffer,
to_float_t to_float, from_float_t from_float, int ith, int nth) {
auto type_x = ggml_type(from_type);
GGML_ASSERT(ggml_type_size(type_x) == nb0);
auto type_y = ggml_type(to_type);
auto row_size_y = ggml_row_size(type_y, ne0);
int64_t nrows = ne1*ne2*ne3;
int64_t nrows_per_thread = (nrows + nth - 1)/nth;
int64_t first_row = nrows_per_thread*ith;
if (first_row >= nrows) return;
int64_t last_row = std::min(first_row + nrows_per_thread, nrows);
for (int64_t row = first_row; row < last_row; ++row) {
int64_t i3 = row/(ne1*ne2);
int64_t i2 = (row - i3*ne1*ne2)/ne1;
int64_t i1 = row - i3*ne1*ne2 - i2*ne1;
const char * cx = (const char *)x + i1*nb1 + i2*nb2 + i3*nb3;
// TODO: special case common types such as f16, q8_0
// (although the performance gains may be too small to justify the added complexity)
to_float((const void *)cx, (float *)work_buffer, ne0);
auto cy = (char *)y + (i3*ne1*ne2 + i2*ne1 + i1)*row_size_y;
from_float((const float *)work_buffer, (void *)cy, ne0);
}
}
size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
IQ1BNQuantizer iq1bn;
auto row_size = ggml_row_size(GGML_TYPE_IQ1_BN, n_per_row);