From b9daa401d76221a568dd8a9ce2497b80e469d378 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 17 Mar 2025 12:04:52 +0200
Subject: [PATCH] Be able to compute for more than 65535 tokens

On CUDA just a quick hack that allows us to concatenate tensors
with more than 65535 rows along the zeroth dimension, as needed
by FlashMLA-2. Also needed some care in the perplexity tool to
avoid int overflows when evaluating the computed logits.
---
 examples/perplexity/perplexity.cpp |  6 +++---
 ggml/src/ggml-cuda/concat.cu       | 21 +++++++++++++++++++++
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 372684f0..95aedce6 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -166,7 +166,7 @@ static void process_logits(
                 break;
             }
             lock.unlock();
-            const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+            const results_log_softmax results = log_softmax(n_vocab, logits + int64_t(i)*n_vocab, tokens[i+1]);
             const double v = -results.log_softmax;
             local_nll += v;
             local_nll2 += v*v;
@@ -200,7 +200,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
                 break;
             }
             lock.unlock();
-            const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
+            const double v = log_softmax(n_vocab, logits + int64_t(i)*n_vocab, log_probs.data() + int64_t(i)*nv, tokens[i+1]);
             local_nll += v;
             local_nll2 += v*v;
         }
@@ -618,7 +618,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
         if (num_batches > 1 && n_outputs > 0) {
             const auto * batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
+            logits.insert(logits.end(), batch_logits, batch_logits + int64_t(n_outputs) * n_vocab);
         }
     }
 
diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu
index 6dbdb352..01004564 100644
--- a/ggml/src/ggml-cuda/concat.cu
+++ b/ggml/src/ggml-cuda/concat.cu
@@ -81,6 +81,19 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float *
 
 static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
     int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
+    if (dim == 0 && ne1 >= 65536) {
+        int64_t nstep = (ne1 + 32767)/32768;
+        for (int64_t istep = 0; istep < nstep; ++istep) {
+            int64_t i1 = 32768*istep;
+            int64_t n1 = i1 + 32768 <= ne1 ? 32768 : ne1 - i1;
+            dim3 gridDim(num_blocks, n1, ne2);
+            const float * xi = x + i1*ne00;
+            const float * yi = y + i1*(ne0 - ne00);
+            float * dst_i = dst + i1*ne0;
+            concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(xi, yi, dst_i, ne0, ne00);
+        }
+        return;
+    }
     dim3 gridDim(num_blocks, ne1, ne2);
     if (dim == 0) {
         concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
@@ -168,6 +181,10 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     if (dim == 0 && src0->nb[0] == ggml_type_size(src0->type) && src1->nb[0] == ggml_type_size(src1->type) &&
                     src0->nb[1] % sizeof(float) == 0 && src1->nb[1] % sizeof(float) == 0) {
         if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+            //if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) {
+            //    fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name);
+            //    GGML_ABORT("fatal error");
+            //}
             const float * src0_d = (const float *)src0->data;
             const float * src1_d = (const float *)src1->data;
             float * dst_d = (float *)dst->data;
@@ -200,6 +217,10 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+        //if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) {
+        //    fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name);
+        //    GGML_ABORT("fatal error");
+        //}
         const float * src0_d = (const float *)src0->data;
         const float * src1_d = (const float *)src1->data;
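
The CUDA hunk above works around a hard launch limit: gridDim.y (and
gridDim.z) may be at most 65535, while concat_f32_cuda maps one grid row
per tensor row, so a dim-0 concat over 65536 or more rows cannot be
launched in one go. The hack splits the launch into steps of 32768 rows
and offsets the x/y/dst pointers by the rows already processed. Below is
a minimal standalone sketch of that chunking idea; the kernel and helper
names (fill_rows, fill_rows_cuda, BLOCK_SIZE) are illustrative, not part
of the patch.

// Sketch: chunk a row-wise kernel launch so gridDim.y never exceeds 65535.
// Assumed names; only the chunking pattern mirrors concat_f32_cuda above.
#include <cstdio>
#include <cuda_runtime.h>

#define BLOCK_SIZE 256

// Toy row-wise kernel: blockIdx.y selects a row within the current chunk;
// row0 is the global index of the chunk's first row.
static __global__ void fill_rows(float * dst, int ne0, int64_t row0) {
    const int i0 = blockIdx.x*blockDim.x + threadIdx.x;
    if (i0 >= ne0) return;
    const int64_t row = row0 + blockIdx.y;                 // global row index
    dst[(int64_t)blockIdx.y*ne0 + i0] = (float)row;        // dst already offset per chunk
}

static void fill_rows_cuda(float * dst, int ne0, int64_t ne1, cudaStream_t stream) {
    const int num_blocks = (ne0 + BLOCK_SIZE - 1)/BLOCK_SIZE;
    const int64_t step  = 32768;                           // stays well below the 65535 cap
    const int64_t nstep = (ne1 + step - 1)/step;
    for (int64_t istep = 0; istep < nstep; ++istep) {
        const int64_t i1 = step*istep;
        const int64_t n1 = i1 + step <= ne1 ? step : ne1 - i1;  // rows in this chunk
        dim3 grid(num_blocks, (unsigned)n1, 1);
        // offset the destination by i1 rows, like xi/yi/dst_i in the patch
        fill_rows<<<grid, BLOCK_SIZE, 0, stream>>>(dst + i1*ne0, ne0, i1);
    }
}

int main() {
    const int     ne0 = 64;
    const int64_t ne1 = 100000;                            // > 65535 rows
    float * d = nullptr;
    cudaMalloc(&d, sizeof(float)*ne0*ne1);
    fill_rows_cuda(d, ne0, ne1, 0);
    float last = 0.0f;
    cudaMemcpy(&last, d + (ne1 - 1)*ne0, sizeof(float), cudaMemcpyDeviceToHost);
    printf("last row value: %g (expected %g)\n", last, (double)(ne1 - 1));
    cudaFree(d);
    return 0;
}

The 32768 step is a convenient power of two under the cap; any chunk size
up to 65535 rows would work the same way.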