mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-25 08:59:30 +00:00
Be able to compute for more than 65535 tokens
On CUDA just a quick hack that allows us to cancatenate tensors with more than 65535 rows along zroth dimension as needed by FlashMLA-2. Also needed some care in the perplexity tool to avoid int overflows when evaluating the computed logits.
This commit is contained in:
@@ -166,7 +166,7 @@ static void process_logits(
|
||||
break;
|
||||
}
|
||||
lock.unlock();
|
||||
const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
|
||||
const results_log_softmax results = log_softmax(n_vocab, logits + int64_t(i)*n_vocab, tokens[i+1]);
|
||||
const double v = -results.log_softmax;
|
||||
local_nll += v;
|
||||
local_nll2 += v*v;
|
||||
@@ -200,7 +200,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
|
||||
break;
|
||||
}
|
||||
lock.unlock();
|
||||
const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
|
||||
const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + int64_t(i)*nv, tokens[i+1]);
|
||||
local_nll += v;
|
||||
local_nll2 += v*v;
|
||||
}
|
||||
@@ -618,7 +618,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||
|
||||
if (num_batches > 1 && n_outputs > 0) {
|
||||
const auto * batch_logits = llama_get_logits(ctx);
|
||||
logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
|
||||
logits.insert(logits.end(), batch_logits, batch_logits + int64_t(n_outputs) * n_vocab);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -81,6 +81,19 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float *
|
||||
|
||||
static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
|
||||
int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
|
||||
if (dim == 0 && ne1 >= 65536) {
|
||||
int64_t nstep = (ne1 + 32767)/32768;
|
||||
for (int64_t istep = 0; istep < nstep; ++istep) {
|
||||
int64_t i1 = 32768*istep;
|
||||
int64_t n1 = i1 + 32768 <= ne1 ? 32768 : ne1 - i1;
|
||||
dim3 gridDim(num_blocks, n1, ne2);
|
||||
const float * xi = x + i1*ne00;
|
||||
const float * yi = y + i1*(ne0 - ne00);
|
||||
float * dst_i = dst + i1*ne0;
|
||||
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(xi, yi, dst_i, ne0, ne00);
|
||||
}
|
||||
return;
|
||||
}
|
||||
dim3 gridDim(num_blocks, ne1, ne2);
|
||||
if (dim == 0) {
|
||||
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
|
||||
@@ -168,6 +181,10 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
if (dim == 0 && src0->nb[0] == ggml_type_size(src0->type) && src1->nb[0] == ggml_type_size(src1->type) &&
|
||||
src0->nb[1] % sizeof(float) == 0 && src1->nb[1] % sizeof(float) == 0) {
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
//if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) {
|
||||
// fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name);
|
||||
// GGML_ABORT("fatal error");
|
||||
//}
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
const float * src1_d = (const float *)src1->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
@@ -200,6 +217,10 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
//if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) {
|
||||
// fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name);
|
||||
// GGML_ABORT("fatal error");
|
||||
//}
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
const float * src1_d = (const float *)src1->data;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user