FlashMLA-2: reduce compute buffer size (CUDA and CPU) (#260)

* FlashMLA-2: eliminate intermediate f32 tensors

This works on the CPU. PP performance is ~13% better for 16k tokens
and compute buffer is quite a bit smaller.

* FlashMLA-2: enable fast path only on the CPU for now

I did implement the necessary ops on CUDA, but something is
still wrong there, so for now we only use it when running
CPU-only.

* FlashMLA-2: slightly smaller computer buffer size

* Prepare wk_b when loading DeepSeek models (if wk_b is missing)

* Add some comments

* Fix case where wkv_b is quantized with k- or i-quants.

* Fix CUDA

There is an issue with quantized GEMV on CUDA when the left operand
(the matrix) is not contiguous. So, for now, we also create wv_b
during model loading and use that instead of the 3D view of wkv_b.

* FlashMLA-2: avoid conversions to f32 also on CUDA

* Be able to compute for more than 65535 tokens

On CUDA just a quick hack that allows us to cancatenate tensors
with more than 65535 rows along zroth dimension as needed by
FlashMLA-2. Also needed some care in the perplexity tool to
avoid int overflows when evaluating the computed logits.

* Reduce memory usage for FlashMLA-2

Oh, also fix int overflow in the CUDA concat implementation.

It is funny how the llama.cpp 64-bit police has gone (almost) everywhere
and replaced 32-bit ints with 64-bit ints, needed or not,
but hasn't done it where it is actually needed.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-03-18 07:36:42 +01:00
committed by GitHub
parent 0f19a500a9
commit 9fe2b06f79
5 changed files with 184 additions and 125 deletions

View File

@@ -166,7 +166,7 @@ static void process_logits(
break;
}
lock.unlock();
const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
const results_log_softmax results = log_softmax(n_vocab, logits + int64_t(i)*n_vocab, tokens[i+1]);
const double v = -results.log_softmax;
local_nll += v;
local_nll2 += v*v;
@@ -200,7 +200,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
break;
}
lock.unlock();
const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + int64_t(i)*nv, tokens[i+1]);
local_nll += v;
local_nll2 += v*v;
}
@@ -618,7 +618,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
if (num_batches > 1 && n_outputs > 0) {
const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
logits.insert(logits.end(), batch_logits, batch_logits + int64_t(n_outputs) * n_vocab);
}
}