Files
ktransformers/kt-kernel/operators/kvcache/kvcache_attn.cpp
2025-11-03 15:19:52 +08:00

1839 lines
102 KiB
C++

/**
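* @Description : Block-wise CPU attention over the KV cache: query quantization, block-similarity scoring for retrieval, per-block attention and log-sum-exp merging of the partial results.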
* @Author : Jianwei Dong
* @Date : 2024-08-26 22:47:06
* @Version : 1.0.0
* @LastEditors : Jianwei Dong
* @LastEditTime : 2024-08-26 22:47:06
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include <chrono>
#include <cmath>
#include "ggml-impl.h"
#include "kvcache.h"
#include "llamafile/sgemm.h"
void KVCache::attention_kvhead_(const uint16_t* q_in_data, ggml_fp16_t* output, float* attn_lse, int batch_size,
WorkerPool* backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
seq_len_ = config_.block_len;
backend->do_work_stealing_job(
batch_size * config_.kv_head_num * max_block_num_after_retrieval_,
[&](int thread_id) {
thread_cur_head_idx_[thread_id].first = -1;
thread_cur_head_idx_[thread_id].second = -1;
},
[&](int task_id) {
int batch_id = task_id / (config_.kv_head_num * max_block_num_after_retrieval_);
int head_id =
(task_id % (config_.kv_head_num * max_block_num_after_retrieval_)) / max_block_num_after_retrieval_;
int block_id = task_id % max_block_num_after_retrieval_;
int thread_id = WorkerPool::thread_local_id;
// If the block is out of the sequence length, skip it.
if (cache_seqlens_[batch_id] / config_.block_len < block_id) {
return;
}
int block_idx = block_table_after_retrieval_kvhead_[batch_id][block_id][head_id];
if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
int seq_len = cache_seqlens_[batch_id] % config_.block_len;
if (seq_len == 0) return;
// Prepare the attention mask for the last block.
int full_blocks = seq_len / 8;
int remaining_bits = seq_len % 8;
// Fill full blocks with 1s
for (int i = 0; i < full_blocks; ++i) {
thread_local_attn_mask_[thread_id][i] = 0xFF;
}
// Fill the remaining bits in the next block
if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1;
} else {
thread_local_attn_mask_[thread_id][full_blocks] = 0;
}
for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
thread_local_attn_mask_[thread_id][i] = 0;
}
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16,
0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
attn_with_kvcache_one_block_(
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
attn_with_kvcache_one_block_(
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
}
} else {
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
attn_with_kvcache_one_block_(
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr,
nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0,
0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0,
0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
}
}
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (batch_id == cur_batch_idx && head_id == cur_head_id) {
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] +
std::log(1.0 + std::exp(thread_local_attn_lse_[thread_id][i] -
thread_local_cur_attn_lse_[thread_id][i]));
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
ggml_vec_scale_f32(config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] +=
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
}
thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
}
} else {
if (cur_batch_idx != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
continue;
}
float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(config_.head_dim,
output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
ggml_vec_scale_f32(config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
thread_cur_head_idx_[thread_id].first = batch_id;
thread_cur_head_idx_[thread_id].second = head_id;
for (int i = 0; i < n_gqa_; i++) {
thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] =
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
}
}
}
},
// Merge the results of the remaining blocks.
[&](int thread_id) {
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (cur_head_id != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse;
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
continue;
}
new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
});
// move the results to output and attn_lse
uint16_t* output_data = reinterpret_cast<uint16_t*>(output);
float* attn_lse_data = attn_lse;
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
for (int i = 0; i < config_.kv_head_num; i++) {
for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
output_data[batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim + i * n_gqa_ * config_.head_dim + j] =
GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]);
}
for (int j = 0; j < n_gqa_; j++) {
attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ + i * n_gqa_ + j] = attn_lse_[batch_idx][i][j];
}
}
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of computing attention: %f s\n", layer_idx,
// diff.count());
}
void KVCache::attention_layer_(const uint16_t* q_in_data, ggml_fp16_t* output, float* attn_lse, int batch_size,
WorkerPool* backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
seq_len_ = config_.block_len;
backend->do_work_stealing_job(
batch_size * config_.kv_head_num * max_block_num_after_retrieval_,
[&](int thread_id) {
thread_cur_head_idx_[thread_id].first = -1;
thread_cur_head_idx_[thread_id].second = -1;
},
[&](int task_id) {
int batch_id = task_id / (config_.kv_head_num * max_block_num_after_retrieval_);
int head_id =
(task_id % (config_.kv_head_num * max_block_num_after_retrieval_)) / max_block_num_after_retrieval_;
int block_id = task_id % max_block_num_after_retrieval_;
int thread_id = WorkerPool::thread_local_id;
// If the block is out of the sequence length, skip it.
if (cache_seqlens_[batch_id] / config_.block_len < block_id) {
return;
}
int block_idx = block_table_after_retrieval_[batch_id][block_id];
if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
int seq_len = cache_seqlens_[batch_id] % config_.block_len;
if (seq_len == 0) return;
// Prepare the attention mask for the last block.
int full_blocks = seq_len / 8;
int remaining_bits = seq_len % 8;
// Fill full blocks with 1s
for (int i = 0; i < full_blocks; ++i) {
thread_local_attn_mask_[thread_id][i] = 0xFF;
}
// Fill the remaining bits in the next block
if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1;
} else {
thread_local_attn_mask_[thread_id][full_blocks] = 0;
}
for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
thread_local_attn_mask_[thread_id][i] = 0;
}
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16,
0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
attn_with_kvcache_one_block_(
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
attn_with_kvcache_one_block_(
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
}
} else {
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
attn_with_kvcache_one_block_(
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr,
nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0,
0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0,
0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
}
}
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (batch_id == cur_batch_idx && head_id == cur_head_id) {
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] +
std::log(1.0 + std::exp(thread_local_attn_lse_[thread_id][i] -
thread_local_cur_attn_lse_[thread_id][i]));
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
ggml_vec_scale_f32(config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] +=
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
}
thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
}
} else {
if (cur_batch_idx != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
continue;
}
float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(config_.head_dim,
output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
ggml_vec_scale_f32(config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
thread_cur_head_idx_[thread_id].first = batch_id;
thread_cur_head_idx_[thread_id].second = head_id;
for (int i = 0; i < n_gqa_; i++) {
thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] =
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
}
}
}
},
// Merge the results of the remaining blocks.
[&](int thread_id) {
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (cur_head_id != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse;
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
continue;
}
new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
});
// move the results to output and attn_lse
uint16_t* output_data = reinterpret_cast<uint16_t*>(output);
float* attn_lse_data = attn_lse;
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
for (int i = 0; i < config_.kv_head_num; i++) {
for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
output_data[batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim + i * n_gqa_ * config_.head_dim + j] =
GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]);
}
for (int j = 0; j < n_gqa_; j++) {
attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ + i * n_gqa_ + j] = attn_lse_[batch_idx][i][j];
}
}
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of computing attention: %f s\n", layer_id_,
// diff.count());
}
void KVCache::attn(const ggml_fp16_t* q_in, ggml_fp16_t* output, float* attn_lse, int layer_idx, int generate_token_idx,
int q_len, int batch_size, int max_block_num, int* block_table, int* cache_seqlens,
int pick_block_num, int init_block_num, int local_block_num, WorkerPool* backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
layer_id_ = layer_idx;
batch_size = batch_size * q_len;
const uint16_t* q_in_data = const_cast<const uint16_t*>(q_in);
quantize_q_(q_in_data, batch_size);
if (config_.retrieval_type == RetrievalType::LAYER) {
attn_initialize_layer_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens);
retrieval_kvcache_layer_(q_in_data, init_block_num, local_block_num, pick_block_num, q_len, generate_token_idx,
batch_size, layer_idx, cache_seqlens, max_block_num, backend);
attention_layer_(q_in_data, output, attn_lse, batch_size, backend);
} else if (config_.retrieval_type == RetrievalType::KVHEAD) {
attn_initialize_kvhead_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens);
retrieval_kvcache_kvhead_(q_in_data, init_block_num, local_block_num, pick_block_num, q_len, generate_token_idx,
batch_size, layer_idx, cache_seqlens, max_block_num, backend);
attention_kvhead_(q_in_data, output, attn_lse, batch_size, backend);
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of computing attention: %f s\n", layer_idx,
// diff.count());
}
/**
 * Appends the new key/value token to the fp16 cache and then runs attention for a single decode step.
 *
 * @param q_in               Query for the new token (fp16).
 * @param k_in               Key for the new token (fp16), written to the cache.
 * @param v_in               Value for the new token (fp16), written to the cache.
 * @param output             Receives the attention output (fp16).
 * @param attn_lse           Receives the per-query-head log-sum-exp values.
 * @param layer_idx          Layer being processed; stored in layer_id_.
 * @param generate_token_idx Forwarded unchanged to attn().
 * @param q_len              Must be 1 (asserted).
 * @param batch_size         Number of batch rows.
 * @param max_block_num      Row length of block_table.
 * @param block_table        Block table forwarded to the cache update and to attn().
 * @param cache_seqlens      Per-row cached lengths. MODIFIED IN PLACE: each entry is incremented by q_len after
 *                           the cache update, so the caller observes the new lengths.
 * @param topk               Number of blocks to pick by similarity (pick_block_num of attn()).
 * @param local              Number of trailing blocks that are always kept (local_block_num of attn()).
 * @param backend            Worker pool used by the parallel stages.
 */
void KVCache::attn_with_kvcache(const ggml_fp16_t* q_in, const ggml_fp16_t* k_in, const ggml_fp16_t* v_in,
                                ggml_fp16_t* output, float* attn_lse, int layer_idx, int generate_token_idx, int q_len,
                                int batch_size, int max_block_num, int* block_table, int* cache_seqlens, int topk,
                                int local, WorkerPool* backend) {
  assert(q_len == 1);
  layer_id_ = layer_idx;

  update_kvcache_fp16(k_in, v_in, layer_idx, block_table, batch_size, max_block_num, cache_seqlens, q_len, backend);

  // The caller's cache_seqlens array is advanced to cover the token that was just written.
  for (int i = 0; i < batch_size; i++) {
    cache_seqlens[i] += q_len;
  }

  // Always keep one leading block; for block lengths of 32 or less, keep as many blocks as fit in 64 tokens.
  int init_block_num = 1;
  if (config_.block_len <= 32) {
    init_block_num = 64 / config_.block_len;
  }

  attn(q_in, output, attn_lse, layer_idx, generate_token_idx, q_len, batch_size, max_block_num, block_table,
       cache_seqlens, topk, init_block_num, local, backend);
}
/**
 * Converts the fp16 query into the representation the attention kernel expects for the configured cache type.
 *
 * For an F16 cache the query is only widened to fp32 and stored in q_fp32_[batch][kv_head]. For any other cache
 * type each KV head's query slice is widened into the scratch buffer q_fp32 and quantized to Q8_0 into
 * q_q8_0_[batch][kv_head].
 *
 * @param q_in_data  Query values as raw fp16 words, indexed [batch][kv_head][n_gqa_][head_dim].
 * @param batch_size Number of batch rows to convert.
 */
void KVCache::quantize_q_(const uint16_t* q_in_data, int batch_size) {
  const int head_elems = n_gqa_ * config_.head_dim;
  const bool cache_is_f16 = (config_.kv_type == ggml_type::GGML_TYPE_F16);

  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
    const int batch_base = batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim;
    for (int i = 0; i < config_.kv_head_num; i++) {
      const uint16_t* q_head = q_in_data + batch_base + i * head_elems;
      if (cache_is_f16) {
        // No quantization on this path: just widen fp16 -> fp32.
        for (int j = 0; j < head_elems; j++) {
          q_fp32_[batch_idx][i][j] = GGML_FP16_TO_FP32(q_head[j]);
        }
      } else {
        // Widen into the shared scratch row, then quantize the whole row to Q8_0.
        for (int j = 0; j < head_elems; j++) {
          q_fp32[j] = GGML_FP16_TO_FP32(q_head[j]);
        }
        quantize_row_q8_0(q_fp32.data(), q_q8_0_[batch_idx][i].data(), head_elems);
      }
    }
  }
}
/**
 * Resets the per-call attention state and loads the block table / sequence lengths for layer-level retrieval.
 *
 * For every batch row this zeroes output_fp32_ and attn_lse_, empties top_similar_block_, fills
 * block_table_before_retrieval_ and cache_seqlens_, and zeroes block_similar_ for the first max_block_num slots.
 *
 * @param batch_size    Number of batch rows to initialize.
 * @param layer_idx     Layer index; only used to look up past_block_num_ when no block table is given.
 * @param block_table   Row-major [batch][max_block_num] block table, or nullptr to use the identity table
 *                      0..max_block_num-1.
 * @param max_block_num In/out. Row length of block_table. When block_table is nullptr it is OVERWRITTEN with
 *                      past_block_num_[layer_idx], and the caller sees the new value.
 * @param cache_seqlens Per-row cached lengths; only read when block_table is not nullptr. Without a block table
 *                      the length is cache_total_len_ if non-zero, otherwise max_block_num * block_len.
 */
void KVCache::attn_initialize_layer_(int batch_size, int layer_idx, int* block_table, int& max_block_num,
                                     int* cache_seqlens) {
  // Reset the shared accumulators and the top-k heap of every batch row.
  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
    for (int i = 0; i < config_.kv_head_num; i++) {
      for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
        output_fp32_[batch_idx][i][j] = 0;
      }
      for (int j = 0; j < n_gqa_; j++) {
        attn_lse_[batch_idx][i][j] = 0;
      }
    }
    while (!top_similar_block_[batch_idx].empty()) {
      top_similar_block_[batch_idx].pop();
    }
  }

  const bool identity_table = (block_table == nullptr);
  if (identity_table) {
    max_block_num = past_block_num_[layer_idx];
  }

  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
    if (!identity_table) {
      cache_seqlens_[batch_idx] = cache_seqlens[batch_idx];
    } else if (cache_total_len_ != 0) {
      cache_seqlens_[batch_idx] = cache_total_len_;
    } else {
      cache_seqlens_[batch_idx] = max_block_num * config_.block_len;
    }
    for (int i = 0; i < max_block_num; i++) {
      if (identity_table) {
        block_table_before_retrieval_[batch_idx][i] = i;
      } else {
        block_table_before_retrieval_[batch_idx][i] = block_table[batch_idx * max_block_num + i];
      }
      block_similar_[batch_idx][i] = 0;
    }
  }
}
/**
 * Scores each retrievable cache block of a layer against the query and stores the scores in block_similar_.
 *
 * Only the blocks between the first init_block_num and the last local_block_num blocks of each sequence are
 * scored; block_similar_[batch][block_id] is indexed by position in block_table_before_retrieval_.
 *
 * Two paths exist:
 *  - batch_size == 1 and a single anchor per block: the query is averaged over q_len, converted to fp16, and
 *    the scores are dot products computed with llamafile_sgemm. One sgemm covers all blocks when their physical
 *    indices are consecutive; otherwise one sgemm is issued per block.
 *  - otherwise: for every element of the averaged query, the largest product with any of the block's anchors is
 *    taken, and these maxima are summed over all heads and dimensions.
 *
 * @param q_in_data       Query as raw fp16 words, indexed [batch][q_len][q_head_num][head_dim].
 * @param batch_size      Number of batch rows.
 * @param layer_idx       Layer whose anchors are used.
 * @param q_len           Number of query tokens averaged per batch row.
 * @param max_block_num   Number of block slots per batch row (task count on the general path).
 * @param cache_seqlens   Unused here; the lengths are read from cache_seqlens_.
 * @param init_block_num  Leading blocks excluded from scoring.
 * @param local_block_num Trailing blocks excluded from scoring.
 * @param pick_block_num  Unused here.
 * @param backend         Worker pool that executes the tasks.
 */
void KVCache::calculate_block_similarity_layer_(const uint16_t* q_in_data, int batch_size, int layer_idx, int q_len,
                                                int max_block_num, int* cache_seqlens, int init_block_num,
                                                int local_block_num, int pick_block_num, WorkerPool* backend) {
  const int q_dim = config_.q_head_num * config_.head_dim;
  // Anchor offsets are computed in 64 bits: layer * max_block_num * anchor_num * q_dim can exceed INT_MAX.
  const long long block_stride = static_cast<long long>(config_.anchor_num) * q_dim;
  const long long layer_base = static_cast<long long>(layer_idx) * config_.max_block_num * block_stride;

  if (batch_size == 1 && config_.anchor_num == 1) {  // TODO: improve batch_size > 1
    for (int batch_id = 0; batch_id < batch_size; batch_id++) {
      const uint16_t* q_batch = q_in_data + batch_id * q_len * q_dim;
      if (q_len == 1) {
        for (int j = 0; j < q_dim; j++) {
          avg_q[batch_id][j] = GGML_FP16_TO_FP32(q_batch[j]);
          avg_q_fp16[batch_id][j] = q_batch[j];
        }
      } else {
        for (int j = 0; j < q_dim; j++) {
          avg_q[batch_id][j] = 0;
        }
        // Accumulate over the whole q_head_num * head_dim row of every token, so that every head (not just
        // the first one) contributes to the average that is divided below.
        for (int i = 0; i < q_len; i++) {
          for (int j = 0; j < q_dim; j++) {
            avg_q[batch_id][j] += GGML_FP16_TO_FP32(q_batch[i * q_dim + j]);
          }
        }
        for (int j = 0; j < q_dim; j++) {
          avg_q[batch_id][j] /= q_len;
          avg_q_fp16[batch_id][j] = GGML_FP32_TO_FP16(avg_q[batch_id][j]);
        }
      }

      const int end_block = cache_seqlens_[batch_id] / config_.block_len - local_block_num;
      const int block_num = end_block - init_block_num;
      if (block_num <= 0) {
        continue;
      }

      // A single sgemm over contiguous anchors is only valid if the physical block indices are consecutive.
      bool is_seq = true;
      for (int i = init_block_num + 1; i < end_block; i++) {
        if (block_table_before_retrieval_[batch_id][i] != block_table_before_retrieval_[batch_id][i - 1] + 1) {
          is_seq = false;
          break;
        }
      }

      if (is_seq) {
        const int nth = backend->get_thread_num();
        const long long first_anchor =
            layer_base + block_table_before_retrieval_[batch_id][init_block_num] * block_stride;
        // Every worker runs its slice (ith of nth) of one block_num x 1 matrix-vector product.
        backend->do_work_stealing_job(
            nth, nullptr,
            [&](int task_id) {
              const bool ok = llamafile_sgemm(
                  block_num, 1, q_dim, anchor_.data() + first_anchor, q_dim, avg_q_fp16[batch_id].data(), q_dim,
                  block_similar_[batch_id].data() + init_block_num, block_num, task_id, nth, GGML_TASK_TYPE_COMPUTE,
                  GGML_TYPE_F16, GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT);
              if (!ok) {
                printf("llamafile_sgemm failed\n");
              }
            },
            nullptr);
      } else {
        // One 1 x 1 product per block, each addressed through the block table.
        backend->do_work_stealing_job(
            block_num, nullptr,
            [&](int task_id) {
              const int block_id = task_id + init_block_num;
              // block_idx is already the physical block index; it must not be looked up in the table again.
              const int block_idx = block_table_before_retrieval_[batch_id][block_id];
              const long long anchor_offset = layer_base + block_idx * block_stride;
              const bool ok = llamafile_sgemm(
                  1, 1, q_dim, anchor_.data() + anchor_offset, q_dim, avg_q_fp16[batch_id].data(), q_dim,
                  block_similar_[batch_id].data() + block_id, 1, 0, 1, GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F16,
                  GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT);
              if (!ok) {
                printf("llamafile_sgemm failed\n");
              }
            },
            nullptr);
      }
    }
  } else {
    backend->do_work_stealing_job(
        batch_size * max_block_num, nullptr,
        [&](int task_id) {
          const int batch_id = task_id / max_block_num;
          const int block_id = task_id % max_block_num;
          const int end_block = cache_seqlens_[batch_id] / config_.block_len - local_block_num;
          if (block_id < init_block_num || block_id >= end_block) {
            return;
          }
          const int block_idx = block_table_before_retrieval_[batch_id][block_id];
          const long long block_base = layer_base + block_idx * block_stride;
          const uint16_t* q_batch = q_in_data + batch_id * q_len * q_dim;

          float sim = 0;
          for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
            for (int i = 0; i < config_.head_dim; i++) {
              const int col = head_id * config_.head_dim + i;
              // Mean of this query element over the q_len tokens.
              float q_i = 0;
              for (int q_id = 0; q_id < q_len; q_id++) {
                q_i += GGML_FP16_TO_FP32(q_batch[q_id * q_dim + col]);
              }
              q_i /= q_len;
              // Largest product with any anchor of the block.
              float qa_i = std::numeric_limits<float>::lowest();
              for (int anchor_id = 0; anchor_id < config_.anchor_num; anchor_id++) {
                qa_i = std::max(qa_i, GGML_FP16_TO_FP32(anchor_[block_base + anchor_id * q_dim + col]) * q_i);
              }
              sim += qa_i;
            }
          }
          block_similar_[batch_id][block_id] = sim;
        },
        nullptr);
  }
}
// Builds block_table_after_retrieval_ for one layer from the scores stored in block_similar_.
//
// For every batch entry whose number of full blocks exceeds
// init_block_num + pick_block_num + local_block_num the resulting table is laid out as
//   [ first init_block_num blocks | blocks drained from top_similar_block_ (in pop order) |
//     last local_block_num full blocks | trailing partial block, if any ]
// cache_seqlens_ is rewritten to the number of tokens covered by that table and the chosen
// blocks are recorded in selected_blocks_history_ / selected_blocks_num_history_.
// Shorter sequences keep their whole table (swapped in) and record a history count of 0.
//
// max_block_num is not read by this function.
void KVCache::select_block_layer_(int batch_size, int layer_idx, int max_block_num, int init_block_num,
                                  int local_block_num, int pick_block_num) {
  // Timer start
  auto start = std::chrono::high_resolution_clock::now();
  const int budget = init_block_num + pick_block_num + local_block_num;

  for (int b = 0; b < batch_size; ++b) {
    const auto slot = (layer_idx - config_.layer_offset) / config_.layer_step;
    auto& before = block_table_before_retrieval_[b];
    auto& after = block_table_after_retrieval_[b];
    auto& seq_len = cache_seqlens_[b];

    // Both values are taken before seq_len is overwritten further down.
    const auto full_blocks = seq_len / config_.block_len;
    const auto tail_tokens = seq_len % config_.block_len;

    // Nothing to prune: every block fits into the budget, so keep the original table.
    if (full_blocks <= budget) {
      after.swap(before);
      selected_blocks_num_history_[slot] = 0;
      continue;
    }

    // Score the candidates between the leading and trailing windows, keeping at most
    // pick_block_num entries in the queue.
    auto& heap = top_similar_block_[b];
    for (int block_id = init_block_num; block_id < full_blocks - local_block_num; ++block_id) {
      heap.push(std::make_pair(block_similar_[b][block_id], before[block_id]));
      if (heap.size() > pick_block_num) {
        heap.pop();
      }
    }

    // 1) leading blocks are kept as-is
    int filled = 0;
    while (filled < init_block_num) {
      after[filled] = before[filled];
      ++filled;
    }
    // 2) picked blocks, in the order the queue releases them
    for (; !heap.empty(); heap.pop()) {
      after[filled++] = heap.top().second;
    }
    // 3) trailing window of full blocks
    for (; filled < budget; ++filled) {
      after[filled] = before[full_blocks - local_block_num + filled - init_block_num - pick_block_num];
    }
    // 4) the partially filled last block, when the sequence does not end on a block boundary
    const bool has_partial_block = (tail_tokens != 0);
    if (has_partial_block) {
      after[filled] = before[full_blocks];
    }
    seq_len = tail_tokens + filled * config_.block_len;
    if (has_partial_block) {
      ++filled;
    }

    // Record the selection for this layer's history slot.
    auto& history_row = selected_blocks_history_[slot][b];
    for (int j = 0; j < filled; ++j) {
      history_row[j] = after[j];
    }
    selected_blocks_num_history_[slot] = filled;
  }

  // Timer end
  auto end = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> diff = end - start;
  // printf("layer %d time of selecting blocks: %f s\n", layer_idx,
  // diff.count());
}
// Retrieves the KV-cache blocks one layer will attend to and publishes them through
// block_table_after_retrieval_ / max_block_num_after_retrieval_.
//
// Three modes:
//  * pick_block_num == -1: no retrieval; the full table is swapped in and the layer's
//    history count is reset to 0.
//  * selection step (generate_token_idx is a multiple of config_.token_step and
//    layer_idx % config_.layer_step == config_.layer_offset): the similarity of each
//    candidate block with the query is calculated, then the first init_block_num blocks,
//    the top pick_block_num similar blocks and the last local_block_num blocks are
//    selected (calculate_block_similarity_layer_ + select_block_layer_).
//  * any other step: the selection stored in selected_blocks_history_ is replayed; a
//    stored count of 0 means "use the full table".
void KVCache::retrieval_kvcache_layer_(const uint16_t* q_in_data, int init_block_num, int local_block_num,
                                       int pick_block_num, int q_len, int generate_token_idx, int batch_size,
                                       int layer_idx, int* cache_seqlens, int& max_block_num, WorkerPool* backend) {
  // Timer start
  auto start = std::chrono::high_resolution_clock::now();
  max_block_num_after_retrieval_ = 0;

  if (pick_block_num == -1) {
    // Retrieval disabled: attend to everything.
    selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = 0;
    max_block_num_after_retrieval_ = max_block_num;
    block_table_after_retrieval_.swap(block_table_before_retrieval_);
  } else if (generate_token_idx % config_.token_step == 0 &&
             layer_idx % config_.layer_step == config_.layer_offset) {
    // Selection step: score the blocks and pick a fresh set (+1 leaves room for a partial last block).
    max_block_num_after_retrieval_ = std::min(max_block_num, init_block_num + pick_block_num + local_block_num + 1);
    calculate_block_similarity_layer_(q_in_data, batch_size, layer_idx, q_len, max_block_num, cache_seqlens,
                                      init_block_num, local_block_num, pick_block_num, backend);
    select_block_layer_(batch_size, layer_idx, max_block_num, init_block_num, local_block_num, pick_block_num);
  } else {
    // Replay step: reuse what the last selection step stored for this history slot.
    const auto slot = (layer_idx - config_.layer_offset) / config_.layer_step;
    if (selected_blocks_num_history_[slot] == 0) {
      max_block_num_after_retrieval_ = max_block_num;
      block_table_after_retrieval_.swap(block_table_before_retrieval_);
    } else {
      // NOTE(review): this keeps the count as it was before any append in the loop below;
      // confirm that consumers are meant to see the pre-append value.
      max_block_num_after_retrieval_ = selected_blocks_num_history_[slot];
      for (int b = 0; b < batch_size; ++b) {
        auto& history_row = selected_blocks_history_[slot][b];
        auto& after = block_table_after_retrieval_[b];
        for (int i = 0; i < max_block_num_after_retrieval_; ++i) {
          after[i] = history_row[i];
        }
        // The sequence has just started a new block: append it to the stored selection.
        if (cache_seqlens[b] % config_.block_len == 1) {
          selected_blocks_num_history_[slot] += 1;
          const int newest = selected_blocks_num_history_[slot];
          const int last_block_idx = block_table_before_retrieval_[b][cache_seqlens[b] / config_.block_len];
          history_row[newest - 1] = last_block_idx;
          after[newest - 1] = last_block_idx;
        }
        // Token count for the replayed table: (stored block count - 1) full blocks plus the
        // current tail length.
        cache_seqlens_[b] = (cache_seqlens_[b] % config_.block_len) +
                            selected_blocks_num_history_[slot] * config_.block_len - config_.block_len;
      }
    }
  }

  // Timer end
  auto end = std::chrono::high_resolution_clock::now();
  // printf("layer %d time of retrieval kvcache: %f s\n", layer_idx,
  // std::chrono::duration<double>(end - start).count());
}
// Runs block-wise attention of the query against every block listed in `block_table`, merges the
// per-block results into output_fp32_ / attn_lse_, and then adds, for each query head,
// exp(block_lse - merged_lse) of every block listed in block_table_after_retrieval_ to attn_sparsity
// (presumably the share of attention mass covered by the retrieved blocks).
//
// q_in_data      fp16 query; read directly only on the F16 cache path, at offset
//                batch * kv_head_num * n_gqa_ * head_dim + head * n_gqa_ * head_dim.
//                The Q4_0 / Q8_0 paths take the query from q_q8_0_ instead.
// attn_sparsity  output, indexed [batch * q_head_num + q_head]; only accumulated with +=, never
//                cleared here.
// batch_size     number of batch entries.
// max_block_num  row stride of block_table and number of block slots scheduled per (batch, head).
// block_table    flat table of block indices, indexed [batch * max_block_num + block_id].
// cache_seqlens  per-batch token counts used for the out-of-range check.
// backend        worker pool that executes the tasks.
void KVCache::calculate_sparsity_layer_(const uint16_t* q_in_data, float* attn_sparsity, int batch_size,
int max_block_num, int* block_table, int* cache_seqlens, WorkerPool* backend
) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
seq_len_ = config_.block_len;
// One task per (batch, kv head, block slot) triple.
backend->do_work_stealing_job(
batch_size * config_.kv_head_num * max_block_num,
// First callback (presumably run once per worker thread before its tasks): mark the thread as
// holding no partially merged (batch, head) result.
[&](int thread_id) {
thread_cur_head_idx_[thread_id].first = -1;
thread_cur_head_idx_[thread_id].second = -1;
},
[&](int task_id) {
// Decode the flat task id: block slot varies fastest, then kv head, then batch.
int batch_id = task_id / (config_.kv_head_num * max_block_num);
int head_id = (task_id % (config_.kv_head_num * max_block_num)) / max_block_num;
int block_id = task_id % max_block_num;
int thread_id = WorkerPool::thread_local_id;
// If the block is out of the sequence length, skip it.
if (cache_seqlens[batch_id] / config_.block_len < block_id) {
return;
}
int block_idx = block_table[batch_id * max_block_num + block_id];
// NOTE(review): the range check above reads the cache_seqlens argument while the partial-block
// test below reads the member cache_seqlens_; confirm both hold the same values at this point.
if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
// Last, partially filled block: only the first seq_len positions hold tokens.
int seq_len = cache_seqlens_[batch_id] % config_.block_len;
if (seq_len == 0) return;
// Prepare the attention mask for the last block.
// The mask holds one bit per position: the low seq_len bits are set, the rest are cleared.
int full_blocks = seq_len / 8;
int remaining_bits = seq_len % 8;
// Fill full blocks with 1s
for (int i = 0; i < full_blocks; ++i) {
thread_local_attn_mask_[thread_id][i] = 0xFF;
}
// Fill the remaining bits in the next block
if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1;
} else {
thread_local_attn_mask_[thread_id][full_blocks] = 0;
}
for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
thread_local_attn_mask_[thread_id][i] = 0;
}
// Masked attention over this block, dispatched on the cache storage type.
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
// F16 cache: query read straight from q_in_data, output written as fp32.
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16,
0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
// Q4_0 cache: query taken from q_q8_0_, output written as Q8_0 and then dequantized to fp32.
attn_with_kvcache_one_block_(
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
// Q8_0 cache: same flow as Q4_0 with Q8_0 key/value storage.
attn_with_kvcache_one_block_(
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
}
} else {
// Full block: same calls as above but with `true, nullptr` in place of `false, mask`
// (presumably "all positions valid, no mask needed" — confirm against attn_with_kvcache_one_block_).
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
attn_with_kvcache_one_block_(
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr,
nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0,
0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0,
0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
}
}
// Remember this block's LSE for each of the kv head's n_gqa_ query heads; the final pass below
// reads it back by block index.
for (int i = 0; i < n_gqa_; i++) {
block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] = thread_local_attn_lse_[thread_id][i];
}
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (batch_id == cur_batch_idx && head_id == cur_head_id) {
// Same (batch, head) as the previous task on this thread: fold the block into the thread-local
// running result. new_lse = cur + log(1 + exp(blk - cur)) = log(exp(cur) + exp(blk)); both partial
// outputs are rescaled by exp(own_lse - new_lse) before being summed.
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] +
std::log(1.0 + std::exp(thread_local_attn_lse_[thread_id][i] -
thread_local_cur_attn_lse_[thread_id][i]));
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
ggml_vec_scale_f32(config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] +=
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
}
thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
}
} else {
// Different (batch, head): first flush the running result of the previous head into the shared
// output_fp32_ / attn_lse_ under that head's mutex, then restart the running result from this block.
if (cur_batch_idx != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
// |attn_lse_| < 1e-6 is treated as "nothing merged yet": copy instead of merging.
// NOTE(review): a genuine merged LSE this close to 0 would be overwritten as well.
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
continue;
}
float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(config_.head_dim,
output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
ggml_vec_scale_f32(config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
// Start a new running result for (batch_id, head_id) from the block just computed.
thread_cur_head_idx_[thread_id].first = batch_id;
thread_cur_head_idx_[thread_id].second = head_id;
for (int i = 0; i < n_gqa_; i++) {
thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] =
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
}
}
}
},
// Merge the results of the remaining blocks.
// Third callback: flush the (batch, head) result each thread still holds into the shared buffers,
// using the same copy-or-merge rule as the in-task flush above.
[&](int thread_id) {
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (cur_head_id != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse;
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
continue;
}
new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
});
// Final pass: for every block in the first max_block_num_after_retrieval_ slots of
// block_table_after_retrieval_, add exp(block_lse - merged_lse) to the query head's entry.
// Query head k belongs to kv head k / n_gqa_, slot k % n_gqa_.
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < max_block_num_after_retrieval_; j++) {
int block_idx = block_table_after_retrieval_[i][j];
for (int k = 0; k < config_.q_head_num; k++) {
attn_sparsity[i * config_.q_head_num + k] +=
std::exp(block_lse_[i][block_idx][k] - attn_lse_[i][k / n_gqa_][k % n_gqa_]);
}
}
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of calculating sparsity: %f s\n", layer_id_,
// diff.count());
}
// Resets the per-call attention state used by the per-KV-head code path.
//
// For every batch entry this
//  * zeroes output_fp32_ and attn_lse_ for all KV heads,
//  * empties top_similar_block_,
//  * copies cache_seqlens[batch] into cache_seqlens_,
//  * replicates the flat block_table row ([batch * max_block_num + block]) across all KV heads of
//    block_table_before_retrieval_kvhead_ and zeroes block_similar_kv_head_.
//
// layer_idx is only referenced by the commented-out timing printf.
void KVCache::attn_initialize_kvhead_(int batch_size, int layer_idx, int* block_table, int& max_block_num,
                                      int* cache_seqlens) {
  // Timer start
  auto start = std::chrono::high_resolution_clock::now();

  for (int b = 0; b < batch_size; ++b) {
    // Accumulators written by the attention merge.
    for (int head = 0; head < config_.kv_head_num; ++head) {
      for (int e = 0; e < n_gqa_ * config_.head_dim; ++e) {
        output_fp32_[b][head][e] = 0;
      }
      for (int g = 0; g < n_gqa_; ++g) {
        attn_lse_[b][head][g] = 0;
      }
    }

    // Drop anything left in the similarity queue.
    auto& pending = top_similar_block_[b];
    while (!pending.empty()) {
      pending.pop();
    }

    // Snapshot the caller's sequence length and block table.
    cache_seqlens_[b] = cache_seqlens[b];
    const int* table_row = block_table + b * max_block_num;
    for (int blk = 0; blk < max_block_num; ++blk) {
      for (int head = 0; head < config_.kv_head_num; ++head) {
        block_table_before_retrieval_kvhead_[b][blk][head] = table_row[blk];
        block_similar_kv_head_[b][blk][head] = 0;
      }
    }
  }

  // Timer end
  auto end = std::chrono::high_resolution_clock::now();
  // printf("layer %d time of initializing attn: %f s\n", layer_idx,
  // std::chrono::duration<double>(end - start).count());
}
// Per-KV-head counterpart of block retrieval: fills block_table_after_retrieval_kvhead_ and
// max_block_num_after_retrieval_ for one layer.
//
// Three modes:
//  * pick_block_num == -1: no retrieval; the whole pre-retrieval table is copied and the layer's
//    history count is reset to 0.
//  * selection step (generate_token_idx is a multiple of config_.token_step and
//    layer_idx % config_.layer_step == config_.layer_offset): blocks are scored and selected via
//    calculate_block_similarity_kvhead_ + select_block_kvhead_.
//  * any other step: the selection stored in selected_blocks_history_kvhead_ is replayed; a stored
//    count of 0 means "use the whole table".
void KVCache::retrieval_kvcache_kvhead_(const uint16_t* q_in_data, int init_block_num, int local_block_num,
                                        int pick_block_num, int q_len, int generate_token_idx, int batch_size,
                                        int layer_idx, int* cache_seqlens, int& max_block_num, WorkerPool* backend) {
  // Timer start
  auto start = std::chrono::high_resolution_clock::now();
  max_block_num_after_retrieval_ = 0;

  // Publishes the untouched table: every (batch, block, kv head) entry is copied across.
  auto keep_all_blocks = [&]() {
    max_block_num_after_retrieval_ = max_block_num;
    for (int b = 0; b < batch_size; ++b) {
      for (int blk = 0; blk < max_block_num; ++blk) {
        for (int head = 0; head < config_.kv_head_num; ++head) {
          block_table_after_retrieval_kvhead_[b][blk][head] = block_table_before_retrieval_kvhead_[b][blk][head];
        }
      }
    }
  };

  if (pick_block_num == -1) {
    // Retrieval disabled.
    selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = 0;
    keep_all_blocks();
  } else if (generate_token_idx % config_.token_step == 0 &&
             layer_idx % config_.layer_step == config_.layer_offset) {
    // Selection step (+1 leaves room for a partial last block).
    max_block_num_after_retrieval_ = std::min(max_block_num, init_block_num + pick_block_num + local_block_num + 1);
    calculate_block_similarity_kvhead_(q_in_data, batch_size, layer_idx, q_len, max_block_num, cache_seqlens,
                                       init_block_num, local_block_num, pick_block_num, backend);
    select_block_kvhead_(batch_size, layer_idx, max_block_num, init_block_num, local_block_num, pick_block_num);
  } else {
    // Replay step.
    const auto slot = (layer_idx - config_.layer_offset) / config_.layer_step;
    if (selected_blocks_num_history_[slot] == 0) {
      keep_all_blocks();
    } else {
      // NOTE(review): this keeps the count as it was before any append in the loop below;
      // confirm that consumers are meant to see the pre-append value.
      max_block_num_after_retrieval_ = selected_blocks_num_history_[slot];
      for (int b = 0; b < batch_size; ++b) {
        auto& history = selected_blocks_history_kvhead_[slot][b];
        auto& after = block_table_after_retrieval_kvhead_[b];
        for (int blk = 0; blk < max_block_num_after_retrieval_; ++blk) {
          for (int head = 0; head < config_.kv_head_num; ++head) {
            after[blk][head] = history[blk][head];
          }
        }
        // The sequence has just started a new block: append it for every KV head.
        if (cache_seqlens[b] % config_.block_len == 1) {
          selected_blocks_num_history_[slot] += 1;
          const int newest = selected_blocks_num_history_[slot];
          for (int head = 0; head < config_.kv_head_num; ++head) {
            const int last_block_idx =
                block_table_before_retrieval_kvhead_[b][cache_seqlens[b] / config_.block_len][head];
            history[newest - 1][head] = last_block_idx;
            after[newest - 1][head] = last_block_idx;
          }
        }
        // Cap the visible length at the retrieval budget plus the current tail.
        cache_seqlens_[b] =
            std::min(cache_seqlens_[b], (cache_seqlens_[b] % config_.block_len) +
                                            (init_block_num + pick_block_num + local_block_num) * config_.block_len);
      }
    }
  }

  // Timer end
  auto end = std::chrono::high_resolution_clock::now();
  // printf("layer %d time of retrieval kvcache: %f s\n", layer_idx,
  // std::chrono::duration<double>(end - start).count());
}
// Per-KV-head variant of the sparsity measurement. Runs block-wise attention of the query against
// every block listed in `block_table`, merges the per-block results into output_fp32_ / attn_lse_,
// and then adds, for each query head, exp(block_lse - merged_lse) of every block that
// block_table_after_retrieval_kvhead_ lists for that head's KV head to attn_sparsity
// (presumably the share of attention mass covered by the retrieved blocks).
//
// q_in_data      fp16 query; read directly only on the F16 cache path, at offset
//                batch * kv_head_num * n_gqa_ * head_dim + head * n_gqa_ * head_dim.
//                The Q4_0 / Q8_0 paths take the query from q_q8_0_ instead.
// attn_sparsity  output, indexed [batch * q_head_num + q_head]; only accumulated with +=, never
//                cleared here.
// batch_size     number of batch entries.
// max_block_num  row stride of block_table and number of block slots scheduled per (batch, head).
// block_table    flat table of block indices, indexed [batch * max_block_num + block_id]; the same
//                row is used for every KV head during the attention pass.
// cache_seqlens  per-batch token counts used for the out-of-range check.
// backend        worker pool that executes the tasks.
void KVCache::calculate_sparsity_kvhead_(const uint16_t* q_in_data, float* attn_sparsity, int batch_size,
int max_block_num, int* block_table, int* cache_seqlens, WorkerPool* backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
seq_len_ = config_.block_len;
// One task per (batch, kv head, block slot) triple.
backend->do_work_stealing_job(
batch_size * config_.kv_head_num * max_block_num,
// First callback (presumably run once per worker thread before its tasks): mark the thread as
// holding no partially merged (batch, head) result.
[&](int thread_id) {
thread_cur_head_idx_[thread_id].first = -1;
thread_cur_head_idx_[thread_id].second = -1;
},
[&](int task_id) {
// Decode the flat task id: block slot varies fastest, then kv head, then batch.
int batch_id = task_id / (config_.kv_head_num * max_block_num);
int head_id = (task_id % (config_.kv_head_num * max_block_num)) / max_block_num;
int block_id = task_id % max_block_num;
int thread_id = WorkerPool::thread_local_id;
// If the block is out of the sequence length, skip it.
if (cache_seqlens[batch_id] / config_.block_len < block_id) {
return;
}
int block_idx = block_table[batch_id * max_block_num + block_id];
// NOTE(review): the range check above reads the cache_seqlens argument while the partial-block
// test below reads the member cache_seqlens_; confirm both hold the same values at this point.
if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
// Last, partially filled block: only the first seq_len positions hold tokens.
int seq_len = cache_seqlens_[batch_id] % config_.block_len;
if (seq_len == 0) return;
// Prepare the attention mask for the last block.
// The mask holds one bit per position: the low seq_len bits are set, the rest are cleared.
int full_blocks = seq_len / 8;
int remaining_bits = seq_len % 8;
// Fill full blocks with 1s
for (int i = 0; i < full_blocks; ++i) {
thread_local_attn_mask_[thread_id][i] = 0xFF;
}
// Fill the remaining bits in the next block
if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1;
} else {
thread_local_attn_mask_[thread_id][full_blocks] = 0;
}
for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
thread_local_attn_mask_[thread_id][i] = 0;
}
// Masked attention over this block, dispatched on the cache storage type.
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
// F16 cache: query read straight from q_in_data, output written as fp32.
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16,
0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
// Q4_0 cache: query taken from q_q8_0_, output written as Q8_0 and then dequantized to fp32.
attn_with_kvcache_one_block_(
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
// Q8_0 cache: same flow as Q4_0 with Q8_0 key/value storage.
attn_with_kvcache_one_block_(
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
}
} else {
// Full block: same calls as above but with `true, nullptr` in place of `false, mask`
// (presumably "all positions valid, no mask needed" — confirm against attn_with_kvcache_one_block_).
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
attn_with_kvcache_one_block_(
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
head_id * n_gqa_ * config_.head_dim],
seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr,
nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
sin_.data());
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0,
0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0,
0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
thread_local_output_q8_0_[thread_id].data(),
thread_local_attn_lse_[thread_id].data(),
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
}
}
// Remember this block's LSE for each of the kv head's n_gqa_ query heads; the final pass below
// reads it back by block index.
for (int i = 0; i < n_gqa_; i++) {
block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] = thread_local_attn_lse_[thread_id][i];
}
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (batch_id == cur_batch_idx && head_id == cur_head_id) {
// Same (batch, head) as the previous task on this thread: fold the block into the thread-local
// running result. new_lse = cur + log(1 + exp(blk - cur)) = log(exp(cur) + exp(blk)); both partial
// outputs are rescaled by exp(own_lse - new_lse) before being summed.
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] +
std::log(1.0 + std::exp(thread_local_attn_lse_[thread_id][i] -
thread_local_cur_attn_lse_[thread_id][i]));
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
ggml_vec_scale_f32(config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] +=
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
}
thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
}
} else {
// Different (batch, head): first flush the running result of the previous head into the shared
// output_fp32_ / attn_lse_ under that head's mutex, then restart the running result from this block.
if (cur_batch_idx != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
// |attn_lse_| < 1e-6 is treated as "nothing merged yet": copy instead of merging.
// NOTE(review): a genuine merged LSE this close to 0 would be overwritten as well.
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
continue;
}
float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(config_.head_dim,
output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
ggml_vec_scale_f32(config_.head_dim,
thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
// Start a new running result for (batch_id, head_id) from the block just computed.
thread_cur_head_idx_[thread_id].first = batch_id;
thread_cur_head_idx_[thread_id].second = head_id;
for (int i = 0; i < n_gqa_; i++) {
thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] =
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
}
}
}
},
// Merge the results of the remaining blocks.
// Third callback: flush the (batch, head) result each thread still holds into the shared buffers,
// using the same copy-or-merge rule as the in-task flush above.
[&](int thread_id) {
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
int cur_head_id = thread_cur_head_idx_[thread_id].second;
if (cur_head_id != -1) {
mutex_[cur_batch_idx][cur_head_id]->lock();
for (int i = 0; i < n_gqa_; i++) {
float new_attn_lse;
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
continue;
}
new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
attn_lse_[cur_batch_idx][cur_head_id][i]));
ggml_vec_scale_f32(config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
for (int j = 0; j < config_.head_dim; j++) {
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
}
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
}
mutex_[cur_batch_idx][cur_head_id]->unlock();
}
});
// Final pass: for each of the first max_block_num_after_retrieval_ slots, look up the block that
// block_table_after_retrieval_kvhead_ holds for the query head's KV head (k / n_gqa_) and add
// exp(block_lse - merged_lse) to that query head's entry.
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < max_block_num_after_retrieval_; j++) {
for (int k = 0; k < config_.q_head_num; k++) {
int block_idx = block_table_after_retrieval_kvhead_[i][j][k / n_gqa_];
attn_sparsity[i * config_.q_head_num + k] +=
std::exp(block_lse_[i][block_idx][k] - attn_lse_[i][k / n_gqa_][k % n_gqa_]);
}
}
}
// Timer end
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
// printf("layer %d time of calculating sparsity: %f s\n", layer_id_,
// diff.count());
}
/**
 * Accumulates a query/anchor similarity score for every retrieval-candidate block.
 *
 * One work-stealing task is issued per (batch, block) pair. Blocks among the first `init_block_num`
 * or at/after `(seq_len / block_len) - local_block_num` are skipped. For each remaining block and each
 * query head, every head dimension contributes
 *     max over anchors of ( anchor[dim] * mean over q_len of q[dim] )
 * which is added to `block_similar_kv_head_[batch][block][head / n_gqa_]`.
 *
 * The scores are accumulated with `+=`; this function does not reset them.
 *
 * @param q_in_data     fp16 queries, indexed as [batch][q_len][q_head_num][head_dim].
 * @param batch_size    Number of sequences.
 * @param layer_idx     Layer index used to locate the anchors inside `anchor_`.
 * @param q_len         Number of query tokens averaged per sequence.
 * @param max_block_num Number of block slots per sequence (task grid width).
 * @param cache_seqlens Not read here; the member `cache_seqlens_` is used instead.
 * @param init_block_num  Leading blocks excluded from scoring.
 * @param local_block_num Trailing full blocks excluded from scoring.
 * @param pick_block_num  Not read here.
 * @param backend       Worker pool that runs the tasks.
 */
void KVCache::calculate_block_similarity_kvhead_(const uint16_t* q_in_data, int batch_size, int layer_idx, int q_len,
                                                 int max_block_num, int* cache_seqlens, int init_block_num,
                                                 int local_block_num, int pick_block_num, WorkerPool* backend) {
  // Timer start
  auto start = std::chrono::high_resolution_clock::now();
  backend->do_work_stealing_job(
      batch_size * max_block_num, nullptr,
      [&](int task_id) {
        const int batch_id = task_id / max_block_num;
        const int block_id = task_id % max_block_num;
        const int seq_len = cache_seqlens_[batch_id];
        if (block_id < init_block_num || block_id >= (seq_len / config_.block_len) - local_block_num) {
          return;
        }
        // The physical block index is taken from kv head 0 of the pre-retrieval table.
        const int block_idx = block_table_before_retrieval_kvhead_[batch_id][block_id][0];

        // Flat offsets are computed in 64 bits: the anchor offset is a product of five factors
        // (layer, block, anchor, head, dim) and can exceed the range of a 32-bit int.
        const int64_t head_dim = config_.head_dim;
        const int64_t row_stride = static_cast<int64_t>(config_.q_head_num) * head_dim;
        const int64_t anchor_block_stride = static_cast<int64_t>(config_.anchor_num) * row_stride;
        const int64_t q_base = static_cast<int64_t>(batch_id) * q_len * row_stride;
        const int64_t anchor_base =
            (static_cast<int64_t>(layer_idx) * config_.max_block_num + block_idx) * anchor_block_stride;

        for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
          for (int i = 0; i < config_.head_dim; i++) {
            const int64_t in_row = head_id * head_dim + i;
            // Mean of this query component over the q_len query tokens.
            float q_i = 0, qa_i = std::numeric_limits<float>::lowest();
            for (int q_id = 0; q_id < q_len; q_id++) {
              q_i += GGML_FP16_TO_FP32(q_in_data[q_base + q_id * row_stride + in_row]);
            }
            q_i /= q_len;
            // Largest anchor response for this component.
            for (int anchor_id = 0; anchor_id < config_.anchor_num; anchor_id++) {
              qa_i = std::max(qa_i, GGML_FP16_TO_FP32(anchor_[anchor_base + anchor_id * row_stride + in_row]) * q_i);
            }
            // All query heads of one GQA group feed the same kv-head score.
            block_similar_kv_head_[batch_id][block_id][head_id / n_gqa_] += qa_i;
          }
        }
      },
      nullptr);
  // Timer end
  auto end = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> diff = end - start;
  // printf("layer %d time of calculating similarity: %f s\n", layer_idx,
  // diff.count());
}
/**
 * Builds the post-retrieval block table for every sequence, separately for each kv head.
 *
 * If a sequence has at most `init_block_num + pick_block_num + local_block_num` full blocks, the
 * pre-retrieval table is copied unchanged for all `max_block_num` slots and the selected-block count
 * for this layer is recorded as 0.
 *
 * Otherwise each kv head's table is laid out as:
 *   1. the first `init_block_num` blocks,
 *   2. the blocks left in `top_similar_block_` after keeping at most `pick_block_num` candidates,
 *      in the order the queue yields them,
 *   3. the last `local_block_num` full blocks,
 *   4. the trailing partial block, when the sequence length is not a multiple of `block_len`.
 * The selection is also stored in `selected_blocks_history_kvhead_`, and `cache_seqlens_` is
 * overwritten with the length of the reduced cache.
 */
void KVCache::select_block_kvhead_(int batch_size, int layer_idx, int max_block_num, int init_block_num,
                                   int local_block_num, int pick_block_num) {
  // Timer start
  auto start = std::chrono::high_resolution_clock::now();
  for (int seq = 0; seq < batch_size; seq++) {
    // cache_seqlens_[seq] is only rewritten at the very end of this iteration, so the derived
    // quantities below stay valid for the whole head loop.
    const int full_blocks = cache_seqlens_[seq] / config_.block_len;
    const int tail_tokens = cache_seqlens_[seq] % config_.block_len;
    const int history_slot = (layer_idx - config_.layer_offset) / config_.layer_step;
    const auto& table_before = block_table_before_retrieval_kvhead_[seq];
    auto& table_after = block_table_after_retrieval_kvhead_[seq];

    // Nothing to drop: keep the table as it is.
    if (full_blocks <= init_block_num + pick_block_num + local_block_num) {
      selected_blocks_num_history_[history_slot] = 0;
      for (int slot = 0; slot < max_block_num; slot++) {
        for (int head = 0; head < config_.kv_head_num; head++) {
          table_after[slot][head] = table_before[slot][head];
        }
      }
      continue;
    }

    int cache_len_after_retrieval = 0;
    for (int head = 0; head < config_.kv_head_num; head++) {
      auto& candidates = top_similar_block_[seq];

      // Feed every middle block into the queue, trimming it to pick_block_num entries.
      const int middle_end = full_blocks - local_block_num;
      for (int blk = init_block_num; blk < middle_end; blk++) {
        candidates.push(std::make_pair(block_similar_kv_head_[seq][blk][head], table_before[blk][head]));
        if (candidates.size() > pick_block_num) {
          candidates.pop();
        }
      }

      int slot = 0;
      // 1. initial blocks
      while (slot < init_block_num) {
        table_after[slot][head] = table_before[slot][head];
        slot++;
      }
      // 2. picked blocks, drained from the queue
      while (!candidates.empty()) {
        table_after[slot][head] = candidates.top().second;
        candidates.pop();
        slot++;
      }
      // 3. local (most recent full) blocks
      const int local_shift = full_blocks - local_block_num - init_block_num - pick_block_num;
      const int full_slots = init_block_num + pick_block_num + local_block_num;
      while (slot < full_slots) {
        table_after[slot][head] = table_before[slot + local_shift][head];
        slot++;
      }
      // The reduced length counts the full slots filled so far plus the partial tail.
      cache_len_after_retrieval = tail_tokens + slot * config_.block_len;
      // 4. trailing partial block
      if (tail_tokens != 0) {
        table_after[slot][head] = table_before[full_blocks][head];
        slot++;
      }

      for (int j = 0; j < slot; j++) {
        selected_blocks_history_kvhead_[history_slot][seq][j][head] = table_after[j][head];
      }
    }
    cache_seqlens_[seq] = cache_len_after_retrieval;
    selected_blocks_num_history_[history_slot] =
        (cache_len_after_retrieval + config_.block_len - 1) / config_.block_len;
  }
  // Timer end
  auto end = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> diff = end - start;
  // printf("layer %d time of selecting block: %f s\n", layer_idx,
  // diff.count())
}
/**
 * Runs block retrieval for one layer and then computes the attention sparsity statistics.
 *
 * The pipeline is: quantize the query, initialise the attention state for the layer, run retrieval,
 * then calculate the sparsity against the original (pre-retrieval) block table. Which variant of
 * each step runs depends on `config_.retrieval_type` (LAYER or KVHEAD); any other retrieval type
 * only quantizes the query.
 *
 * NOTE: `batch_size` is overwritten with 1 before use, so only a single sequence is processed
 * regardless of the value passed in.
 */
void KVCache::get_attn_sparsity(const ggml_fp16_t* q_in, float* attn_sparsity, int layer_idx, int generate_token_idx,
                                int q_len, int batch_size, int max_block_num, int* block_table, int* cache_seqlens,
                                int* block_table_origin, int* cache_seqlens_origin, int max_block_num_origin, int topk,
                                int local, WorkerPool* backend) {
  // Timer start
  auto t_begin = std::chrono::high_resolution_clock::now();
  int worker_count = backend->get_thread_num();
  layer_id_ = layer_idx;

  // Only one sequence is handled here.
  batch_size = 1;

  const uint16_t* q_half = const_cast<const uint16_t*>(q_in);
  quantize_q_(q_half, batch_size);

  switch (config_.retrieval_type) {
    case RetrievalType::LAYER:
      attn_initialize_layer_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens);
      retrieval_kvcache_layer_(q_half, 1, local, topk, q_len, generate_token_idx, batch_size, layer_idx,
                               cache_seqlens, max_block_num, backend);
      calculate_sparsity_layer_(q_half, attn_sparsity, batch_size, max_block_num_origin, block_table_origin,
                                cache_seqlens_origin, backend);
      break;
    case RetrievalType::KVHEAD:
      attn_initialize_kvhead_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens);
      retrieval_kvcache_kvhead_(q_half, 1, local, topk, q_len, generate_token_idx, batch_size, layer_idx,
                                cache_seqlens, max_block_num, backend);
      calculate_sparsity_kvhead_(q_half, attn_sparsity, batch_size, max_block_num_origin, block_table_origin,
                                 cache_seqlens_origin, backend);
      break;
    default:
      break;
  }
}
void KVCache::attn_with_kvcache_one_block_(int head_dim, int bsz,
ggml_type q_type, // GGML data type of `Q`, only supports fp16 and q8_0
// [bsz, head_dim]
// Quantization is always on the head_dim dimension (per_token). If
// head_dim % 32 != 0, an error will be raised. The size must be bsz *
// head_dim/32 * qtype_size.
const void* q,
int past_kv_len, int past_kv_offset,
bool is_full_attn, // true indicates a full 1 mask
// If is_full_attn = false, a bit matrix representing the mask is
// passed. [bsz, past_kv_len]
const uint8_t* attn_mask,
ggml_type k_type, // GGML data type of `K Cache`, only supports fp16,
// q4_0, q8_0
int k_quant_type, // 0 for per_token, 1 for per_channel, others raise an
// error
// [seq_len, head_dim]
// If quant_type == 0, head_dim % 32 must be 0.
// If quant_type == 1, seq_len % 32 must be 0.
const void* k_cache,
// k_anchor_type must be fp16
int num_k_anchor, // num_k_anchor == 0 indicates no anchor
// [num_k_anchor, head_dim]
const void* k_cache_anchors,
// Each token is associated with the nearest previous position's anchor,
// with the same distance.
const int* k_cache_anchor_pos,
// v_cache similar to k_cache
ggml_type v_type, int v_quant_type,
// [head_dim, seq_len]
const void* v_cache, int num_v_anchor, const void* v_cache_anchors,
const int* v_cache_anchor_pos,
// Pre-allocated buffer for intermediate calculations [bsz,
// past_kv_len]. No malloc is performed inside this function.
float* attn_score,
// Output: [bsz, head_dim], with the same type as q_type
void* output,
// [bsz]
float* lse,
// Pre-allocated temporary buffer with sufficient size:
// (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * past_kv_len *
// head_dim + past_kv_len * head_dim / 32) bytes.
void* draft,
// Apply rotary embedding online
const int* rotary_angle, const void* rotary_cos, const void* rotary_sin
// rotary_cos=None,
// rotary_sin=None,
// cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
// cache_batch_idx: Optional[torch.Tensor] = None,
// rotary_interleaved=True,
// // Not supported for now
// window_size=(-1, -1), # -1 means infinite context window
// alibi_slopes=None,
) {
assert(head_dim % 32 == 0);
assert(k_quant_type == 0);
assert(v_quant_type == 1);
assert(q_type == GGML_TYPE_F16 || q_type == GGML_TYPE_Q8_0);
if (q_type == GGML_TYPE_F16) {
assert(k_type == GGML_TYPE_F16);
assert(v_type == GGML_TYPE_F16);
// attn = q * k + q * k_anchor
// TODO: anchor
assert(num_k_anchor == 0);
if (rotary_angle != nullptr) {
ggml_fp16_t* k_cache_with_rope_fp16 =
(reinterpret_cast<ggml_fp16_t*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
sizeof(float) * bsz * head_dim);
// dequant k_cache and apply rope
// k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i)
// k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i)
// k(i)cos(i) -> k_rope(i)
// k(i)sin(i+l) -> k_rope(i+l)
// k(i)cos(i) -> k_rope(i)
// -k(i)sin(i-l) -> k_rope(i-l)
std::vector<float> block_fp32(32);
for (int k = 0; k < past_kv_len; k++) {
int angle = rotary_angle[k];
for (int l = 0; l < head_dim / 32; l++) {
for (int m = 0; m < 32; m++) {
float x = GGML_FP16_TO_FP32(((ggml_fp16_t*)k_cache)[k * head_dim + l * 32 + m]);
float sin_val = GGML_FP16_TO_FP32(((ggml_fp16_t*)rotary_sin)[angle * head_dim + l * 32 + m]);
float cos_val = GGML_FP16_TO_FP32(((ggml_fp16_t*)rotary_cos)[angle * head_dim + l * 32 + m]);
if (l * 32 + m < head_dim / 2) {
k_cache_with_rope_fp16[k * head_dim + l * 32 + m] = GGML_FP32_TO_FP16(x * cos_val);
k_cache_with_rope_fp16[k * head_dim + l * 32 + m + head_dim / 2] = GGML_FP32_TO_FP16(-x * sin_val);
} else {
k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =
GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(k_cache_with_rope_fp16[k * head_dim + l * 32 + m]) + x * sin_val);
k_cache_with_rope_fp16[k * head_dim + l * 32 + m - head_dim / 2] = GGML_FP32_TO_FP16(
GGML_FP16_TO_FP32(k_cache_with_rope_fp16[k * head_dim + l * 32 + m - head_dim / 2]) - x * cos_val);
}
}
}
}
llamafile_sgemm(past_kv_len, bsz, head_dim, (ggml_fp16_t*)k_cache_with_rope_fp16, head_dim, (ggml_fp16_t*)q,
head_dim, attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16,
GGML_TYPE_F32, GGML_PREC_DEFAULT);
} else {
bool ok = llamafile_sgemm(past_kv_len, bsz, head_dim, (ggml_fp16_t*)k_cache, head_dim, (ggml_fp16_t*)q, head_dim,
attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16,
GGML_TYPE_F32, GGML_PREC_DEFAULT);
if (!ok) {
printf("llamafile_sgemm failed\n");
}
}
// attn = attn * scale
float scale_factor = 1.0 / std::sqrt(float(head_dim));
ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor);
// attn = attn & mask
if (!is_full_attn) {
for (int i = 0; i < bsz; i++) {
for (int j = 0; j < past_kv_len; j++) {
int index = i * past_kv_len + j;
if (!(attn_mask[j / 8] & (1 << (j % 8)))) {
attn_score[index] = std::numeric_limits<float>::lowest();
}
}
}
}
// attn = softmax(attn)
for (int i = 0; i < bsz; i++) {
float sum_exp = 0;
for (int j = 0; j < past_kv_len; j++) {
attn_score[i * past_kv_len + j] = std::exp(attn_score[i * past_kv_len + j]);
sum_exp += attn_score[i * past_kv_len + j];
}
for (int j = 0; j < past_kv_len; j++) {
attn_score[i * past_kv_len + j] /= sum_exp;
}
if (lse != nullptr) {
lse[i] = std::log(sum_exp);
}
}
// output = attn * v + attn * v_anchor
// std::vector<float> sum(bsz * head_dim);
float* sum =
reinterpret_cast<float*>(reinterpret_cast<char*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0);
// float* attn_score_fp16(bsz, past_kv_len)
ggml_fp16_t* attn_score_fp16 = (reinterpret_cast<ggml_fp16_t*>(reinterpret_cast<char*>(draft) +
sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
sizeof(float) * bsz * head_dim));
for (int i = 0; i < bsz * past_kv_len; i++) {
attn_score_fp16[i] = GGML_FP32_TO_FP16(attn_score[i]);
}
// TODO: anchor
assert(num_v_anchor == 0);
bool ok = llamafile_sgemm(head_dim, bsz, past_kv_len, (ggml_fp16_t*)v_cache, past_kv_len,
(ggml_fp16_t*)attn_score_fp16, past_kv_len, sum, head_dim, 0, 1, GGML_TASK_TYPE_COMPUTE,
v_type, GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT);
if (!ok) {
printf("llamafile_sgemm failed\n");
}
// copy to output
for (int i = 0; i < bsz; i++) {
for (int j = 0; j < head_dim; j++) {
((float*)output)[i * head_dim + j] = sum[i * head_dim + j];
}
}
} else {
assert(k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q8_0);
assert(v_type == GGML_TYPE_Q4_0 || v_type == GGML_TYPE_Q8_0);
// attn = q * k + q * k_anchor
// TODO: anchor
assert(num_k_anchor == 0);
if (rotary_angle != nullptr) {
ggml_fp16_t* k_cache_with_rope_fp16 =
(reinterpret_cast<ggml_fp16_t*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
sizeof(float) * bsz * head_dim);
block_q4_0* k_cache_with_rope_q4 =
(reinterpret_cast<block_q4_0*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
sizeof(float) * bsz * head_dim) +
sizeof(ggml_fp16_t) * bsz * head_dim;
// dequant k_cache and apply rope
// k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i)
// k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i)
// k(i)cos(i) -> k_rope(i)
// k(i)sin(i+l) -> k_rope(i+l)
// k(i)cos(i) -> k_rope(i)
// -k(i)sin(i-l) -> k_rope(i-l)
std::vector<float> block_fp32(32);
for (int k = 0; k < past_kv_len; k++) {
int angle = rotary_angle[k];
for (int l = 0; l < head_dim / 32; l++) {
block_q4_0 block = ((block_q4_0*)k_cache)[k * head_dim / 32 + l];
dequantize_row_q4_0(&block, block_fp32.data(), 32);
for (int m = 0; m < 32; m++) {
float sin_val = GGML_FP16_TO_FP32(((ggml_fp16_t*)rotary_sin)[angle * head_dim + l * 32 + m]);
float cos_val = GGML_FP16_TO_FP32(((ggml_fp16_t*)rotary_cos)[angle * head_dim + l * 32 + m]);
if (l * 32 + m < head_dim / 2) {
k_cache_with_rope_fp16[k * head_dim + l * 32 + m] = GGML_FP32_TO_FP16(block_fp32[m] * cos_val);
k_cache_with_rope_fp16[k * head_dim + l * 32 + m + head_dim / 2] =
GGML_FP32_TO_FP16(-block_fp32[m] * sin_val);
} else {
k_cache_with_rope_fp16[k * head_dim + l * 32 + m] += GGML_FP32_TO_FP16(block_fp32[m] * sin_val);
k_cache_with_rope_fp16[k * head_dim + l * 32 + m - head_dim / 2] -=
GGML_FP32_TO_FP16(block_fp32[m] * cos_val);
}
}
}
}
// quantize k_cache_with_rope_fp16
for (int k = 0; k < past_kv_len; k++) {
for (int l = 0; l < head_dim / 32; l++) {
for (int m = 0; m < 32; m++) {
block_fp32[m] = GGML_FP16_TO_FP32(k_cache_with_rope_fp16[k * head_dim + l * 32 + m]);
}
quantize_row_q4_0(block_fp32.data(), &k_cache_with_rope_q4[k * head_dim / 32 + l], 32);
}
}
llamafile_sgemm(past_kv_len, bsz, head_dim / 32, (block_q4_0*)k_cache_with_rope_q4, head_dim / 32, (block_q8_0*)q,
head_dim / 32, attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_Q8_0,
GGML_TYPE_F32, GGML_PREC_DEFAULT);
} else {
llamafile_sgemm(past_kv_len, bsz, head_dim / 32, (block_q4_0*)k_cache, head_dim / 32, (block_q8_0*)q,
head_dim / 32, attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_Q8_0,
GGML_TYPE_F32, GGML_PREC_DEFAULT);
}
// attn = attn * scale
float scale_factor = 1.0 / std::sqrt(float(head_dim));
ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor);
// attn = attn & mask
if (!is_full_attn) {
for (int i = 0; i < bsz; i++) {
for (int j = 0; j < past_kv_len; j++) {
int index = i * past_kv_len + j;
if (!(attn_mask[j / 8] & (1 << (j % 8)))) {
attn_score[index] = std::numeric_limits<float>::lowest();
}
}
}
}
// attn = softmax(attn)
for (int i = 0; i < bsz; i++) {
float sum_exp = 0;
for (int j = 0; j < past_kv_len; j++) {
attn_score[i * past_kv_len + j] = std::exp(attn_score[i * past_kv_len + j]);
sum_exp += attn_score[i * past_kv_len + j];
}
for (int j = 0; j < past_kv_len; j++) {
attn_score[i * past_kv_len + j] /= sum_exp;
}
if (lse != nullptr) {
lse[i] = std::log(sum_exp);
}
}
// output = attn * v + attn * v_anchor
// std::vector<block_q8_0> attn_q8_0(bsz * past_kv_len / QK8_0);
block_q8_0* attn_q8_0 = reinterpret_cast<block_q8_0*>(draft);
quantize_row_q8_0(attn_score, attn_q8_0, bsz * past_kv_len);
// std::vector<float> sum(bsz * head_dim);
float* sum =
reinterpret_cast<float*>(reinterpret_cast<char*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0);
// TODO: anchor
assert(num_v_anchor == 0);
llamafile_sgemm(head_dim, bsz, past_kv_len / 32, (block_q4_0*)v_cache, past_kv_len / 32, attn_q8_0,
past_kv_len / 32, sum, head_dim, 0, 1, GGML_TASK_TYPE_COMPUTE, v_type, GGML_TYPE_Q8_0,
GGML_TYPE_F32, GGML_PREC_DEFAULT);
quantize_row_q8_0(sum, (block_q8_0*)output, bsz * head_dim);
}
}