mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-24 22:55:49 +00:00
1839 lines
102 KiB
C++
1839 lines
102 KiB
C++
/**
|
|
* @Description :
|
|
* @Author : Jianwei Dong
|
|
* @Date : 2024-08-26 22:47:06
|
|
* @Version : 1.0.0
|
|
* @LastEditors : Jianwei Dong
|
|
* @LastEditTime : 2024-08-26 22:47:06
|
|
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
|
**/
|
|
|
|
#include <chrono>
|
|
#include <cmath>
|
|
|
|
#include "ggml-impl.h"
|
|
#include "kvcache.h"
|
|
#include "llamafile/sgemm.h"
|
|
|
|
void KVCache::attention_kvhead_(const uint16_t* q_in_data, ggml_fp16_t* output, float* attn_lse, int batch_size,
|
|
WorkerPool* backend) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
seq_len_ = config_.block_len;
|
|
|
|
backend->do_work_stealing_job(
|
|
batch_size * config_.kv_head_num * max_block_num_after_retrieval_,
|
|
[&](int thread_id) {
|
|
thread_cur_head_idx_[thread_id].first = -1;
|
|
thread_cur_head_idx_[thread_id].second = -1;
|
|
},
|
|
[&](int task_id) {
|
|
int batch_id = task_id / (config_.kv_head_num * max_block_num_after_retrieval_);
|
|
int head_id =
|
|
(task_id % (config_.kv_head_num * max_block_num_after_retrieval_)) / max_block_num_after_retrieval_;
|
|
int block_id = task_id % max_block_num_after_retrieval_;
|
|
int thread_id = WorkerPool::thread_local_id;
|
|
|
|
// If the block is out of the sequence length, skip it.
|
|
if (cache_seqlens_[batch_id] / config_.block_len < block_id) {
|
|
return;
|
|
}
|
|
int block_idx = block_table_after_retrieval_kvhead_[batch_id][block_id][head_id];
|
|
if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
|
|
int seq_len = cache_seqlens_[batch_id] % config_.block_len;
|
|
if (seq_len == 0) return;
|
|
|
|
// Prepare the attention mask for the last block.
|
|
int full_blocks = seq_len / 8;
|
|
int remaining_bits = seq_len % 8;
|
|
// Fill full blocks with 1s
|
|
for (int i = 0; i < full_blocks; ++i) {
|
|
thread_local_attn_mask_[thread_id][i] = 0xFF;
|
|
}
|
|
// Fill the remaining bits in the next block
|
|
if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
|
|
thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1;
|
|
} else {
|
|
thread_local_attn_mask_[thread_id][full_blocks] = 0;
|
|
}
|
|
|
|
for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
|
|
thread_local_attn_mask_[thread_id][i] = 0;
|
|
}
|
|
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
|
|
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
|
|
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
|
|
head_id * n_gqa_ * config_.head_dim],
|
|
seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16,
|
|
0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(),
|
|
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
|
|
attn_with_kvcache_one_block_(
|
|
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
|
|
GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
|
|
sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
|
|
attn_with_kvcache_one_block_(
|
|
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
|
|
GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
|
|
sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
}
|
|
} else {
|
|
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
|
|
attn_with_kvcache_one_block_(
|
|
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
|
|
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
|
|
head_id * n_gqa_ * config_.head_dim],
|
|
seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr,
|
|
nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
|
|
sin_.data());
|
|
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
|
|
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0,
|
|
0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
|
|
thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(),
|
|
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
|
|
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0,
|
|
0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
|
|
thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(),
|
|
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
}
|
|
}
|
|
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
|
|
int cur_head_id = thread_cur_head_idx_[thread_id].second;
|
|
if (batch_id == cur_batch_idx && head_id == cur_head_id) {
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] +
|
|
std::log(1.0 + std::exp(thread_local_attn_lse_[thread_id][i] -
|
|
thread_local_cur_attn_lse_[thread_id][i]));
|
|
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
|
|
ggml_vec_scale_f32(config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse));
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] +=
|
|
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
|
|
}
|
|
} else {
|
|
if (cur_batch_idx != -1) {
|
|
mutex_[cur_batch_idx][cur_head_id]->lock();
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
continue;
|
|
}
|
|
float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
|
|
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
|
|
attn_lse_[cur_batch_idx][cur_head_id][i]));
|
|
ggml_vec_scale_f32(config_.head_dim,
|
|
output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
|
|
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
|
|
ggml_vec_scale_f32(config_.head_dim,
|
|
thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
|
|
}
|
|
mutex_[cur_batch_idx][cur_head_id]->unlock();
|
|
}
|
|
thread_cur_head_idx_[thread_id].first = batch_id;
|
|
thread_cur_head_idx_[thread_id].second = head_id;
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i];
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] =
|
|
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
}
|
|
}
|
|
},
|
|
// Merge the results of the remaining blocks.
|
|
[&](int thread_id) {
|
|
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
|
|
int cur_head_id = thread_cur_head_idx_[thread_id].second;
|
|
if (cur_head_id != -1) {
|
|
mutex_[cur_batch_idx][cur_head_id]->lock();
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
float new_attn_lse;
|
|
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
continue;
|
|
}
|
|
new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
|
|
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
|
|
attn_lse_[cur_batch_idx][cur_head_id][i]));
|
|
ggml_vec_scale_f32(config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
|
|
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
|
|
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
|
|
}
|
|
mutex_[cur_batch_idx][cur_head_id]->unlock();
|
|
}
|
|
});
|
|
// move the results to output and attn_lse
|
|
uint16_t* output_data = reinterpret_cast<uint16_t*>(output);
|
|
float* attn_lse_data = attn_lse;
|
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
|
for (int i = 0; i < config_.kv_head_num; i++) {
|
|
for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
|
|
output_data[batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim + i * n_gqa_ * config_.head_dim + j] =
|
|
GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]);
|
|
}
|
|
for (int j = 0; j < n_gqa_; j++) {
|
|
attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ + i * n_gqa_ + j] = attn_lse_[batch_idx][i][j];
|
|
}
|
|
}
|
|
}
|
|
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
std::chrono::duration<double> diff = end - start;
|
|
// printf("layer %d time of computing attention: %f s\n", layer_idx,
|
|
// diff.count());
|
|
}
|
|
|
|
void KVCache::attention_layer_(const uint16_t* q_in_data, ggml_fp16_t* output, float* attn_lse, int batch_size,
|
|
WorkerPool* backend) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
seq_len_ = config_.block_len;
|
|
backend->do_work_stealing_job(
|
|
batch_size * config_.kv_head_num * max_block_num_after_retrieval_,
|
|
[&](int thread_id) {
|
|
thread_cur_head_idx_[thread_id].first = -1;
|
|
thread_cur_head_idx_[thread_id].second = -1;
|
|
},
|
|
[&](int task_id) {
|
|
int batch_id = task_id / (config_.kv_head_num * max_block_num_after_retrieval_);
|
|
int head_id =
|
|
(task_id % (config_.kv_head_num * max_block_num_after_retrieval_)) / max_block_num_after_retrieval_;
|
|
int block_id = task_id % max_block_num_after_retrieval_;
|
|
int thread_id = WorkerPool::thread_local_id;
|
|
// If the block is out of the sequence length, skip it.
|
|
if (cache_seqlens_[batch_id] / config_.block_len < block_id) {
|
|
return;
|
|
}
|
|
int block_idx = block_table_after_retrieval_[batch_id][block_id];
|
|
if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
|
|
int seq_len = cache_seqlens_[batch_id] % config_.block_len;
|
|
if (seq_len == 0) return;
|
|
|
|
// Prepare the attention mask for the last block.
|
|
int full_blocks = seq_len / 8;
|
|
int remaining_bits = seq_len % 8;
|
|
|
|
// Fill full blocks with 1s
|
|
for (int i = 0; i < full_blocks; ++i) {
|
|
thread_local_attn_mask_[thread_id][i] = 0xFF;
|
|
}
|
|
// Fill the remaining bits in the next block
|
|
if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
|
|
thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1;
|
|
} else {
|
|
thread_local_attn_mask_[thread_id][full_blocks] = 0;
|
|
}
|
|
|
|
for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
|
|
thread_local_attn_mask_[thread_id][i] = 0;
|
|
}
|
|
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
|
|
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
|
|
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
|
|
head_id * n_gqa_ * config_.head_dim],
|
|
seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16,
|
|
0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(),
|
|
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
|
|
attn_with_kvcache_one_block_(
|
|
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
|
|
GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
|
|
sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
|
|
attn_with_kvcache_one_block_(
|
|
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
|
|
GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
|
|
sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
}
|
|
} else {
|
|
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
|
|
attn_with_kvcache_one_block_(
|
|
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
|
|
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
|
|
head_id * n_gqa_ * config_.head_dim],
|
|
seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr,
|
|
nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
|
|
sin_.data());
|
|
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
|
|
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0,
|
|
0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
|
|
thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(),
|
|
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
|
|
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0,
|
|
0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
|
|
thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(),
|
|
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
}
|
|
}
|
|
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
|
|
int cur_head_id = thread_cur_head_idx_[thread_id].second;
|
|
if (batch_id == cur_batch_idx && head_id == cur_head_id) {
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] +
|
|
std::log(1.0 + std::exp(thread_local_attn_lse_[thread_id][i] -
|
|
thread_local_cur_attn_lse_[thread_id][i]));
|
|
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
|
|
ggml_vec_scale_f32(config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse));
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] +=
|
|
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
|
|
}
|
|
} else {
|
|
if (cur_batch_idx != -1) {
|
|
mutex_[cur_batch_idx][cur_head_id]->lock();
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
continue;
|
|
}
|
|
float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
|
|
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
|
|
attn_lse_[cur_batch_idx][cur_head_id][i]));
|
|
ggml_vec_scale_f32(config_.head_dim,
|
|
output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
|
|
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
|
|
ggml_vec_scale_f32(config_.head_dim,
|
|
thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
|
|
}
|
|
mutex_[cur_batch_idx][cur_head_id]->unlock();
|
|
}
|
|
thread_cur_head_idx_[thread_id].first = batch_id;
|
|
thread_cur_head_idx_[thread_id].second = head_id;
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i];
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] =
|
|
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
}
|
|
}
|
|
},
|
|
// Merge the results of the remaining blocks.
|
|
[&](int thread_id) {
|
|
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
|
|
int cur_head_id = thread_cur_head_idx_[thread_id].second;
|
|
if (cur_head_id != -1) {
|
|
mutex_[cur_batch_idx][cur_head_id]->lock();
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
float new_attn_lse;
|
|
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
continue;
|
|
}
|
|
new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
|
|
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
|
|
attn_lse_[cur_batch_idx][cur_head_id][i]));
|
|
ggml_vec_scale_f32(config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
|
|
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
|
|
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
|
|
}
|
|
mutex_[cur_batch_idx][cur_head_id]->unlock();
|
|
}
|
|
});
|
|
|
|
// move the results to output and attn_lse
|
|
uint16_t* output_data = reinterpret_cast<uint16_t*>(output);
|
|
float* attn_lse_data = attn_lse;
|
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
|
for (int i = 0; i < config_.kv_head_num; i++) {
|
|
for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
|
|
output_data[batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim + i * n_gqa_ * config_.head_dim + j] =
|
|
GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]);
|
|
}
|
|
for (int j = 0; j < n_gqa_; j++) {
|
|
attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ + i * n_gqa_ + j] = attn_lse_[batch_idx][i][j];
|
|
}
|
|
}
|
|
}
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
std::chrono::duration<double> diff = end - start;
|
|
// printf("layer %d time of computing attention: %f s\n", layer_id_,
|
|
// diff.count());
|
|
}
|
|
|
|
void KVCache::attn(const ggml_fp16_t* q_in, ggml_fp16_t* output, float* attn_lse, int layer_idx, int generate_token_idx,
|
|
int q_len, int batch_size, int max_block_num, int* block_table, int* cache_seqlens,
|
|
int pick_block_num, int init_block_num, int local_block_num, WorkerPool* backend) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
layer_id_ = layer_idx;
|
|
batch_size = batch_size * q_len;
|
|
|
|
const uint16_t* q_in_data = const_cast<const uint16_t*>(q_in);
|
|
|
|
quantize_q_(q_in_data, batch_size);
|
|
if (config_.retrieval_type == RetrievalType::LAYER) {
|
|
attn_initialize_layer_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens);
|
|
retrieval_kvcache_layer_(q_in_data, init_block_num, local_block_num, pick_block_num, q_len, generate_token_idx,
|
|
batch_size, layer_idx, cache_seqlens, max_block_num, backend);
|
|
attention_layer_(q_in_data, output, attn_lse, batch_size, backend);
|
|
} else if (config_.retrieval_type == RetrievalType::KVHEAD) {
|
|
attn_initialize_kvhead_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens);
|
|
retrieval_kvcache_kvhead_(q_in_data, init_block_num, local_block_num, pick_block_num, q_len, generate_token_idx,
|
|
batch_size, layer_idx, cache_seqlens, max_block_num, backend);
|
|
attention_kvhead_(q_in_data, output, attn_lse, batch_size, backend);
|
|
}
|
|
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
std::chrono::duration<double> diff = end - start;
|
|
// printf("layer %d time of computing attention: %f s\n", layer_idx,
|
|
// diff.count());
|
|
}
|
|
|
|
void KVCache::attn_with_kvcache(const ggml_fp16_t* q_in, const ggml_fp16_t* k_in, const ggml_fp16_t* v_in,
|
|
ggml_fp16_t* output, float* attn_lse, int layer_idx, int generate_token_idx, int q_len,
|
|
int batch_size, int max_block_num, int* block_table, int* cache_seqlens, int topk,
|
|
int local, WorkerPool* backend) {
|
|
// printf("attn_with_kvcache start\n");
|
|
assert(q_len == 1);
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
|
|
layer_id_ = layer_idx;
|
|
|
|
update_kvcache_fp16(k_in, v_in, layer_idx, block_table, batch_size, max_block_num, cache_seqlens, q_len, backend);
|
|
// printf("update finished.\n");
|
|
|
|
// cache_seqlens memory is modified.
|
|
for (int i = 0; i < batch_size; i++) {
|
|
cache_seqlens[i] += q_len;
|
|
}
|
|
int init_block_num = 1;
|
|
if (config_.block_len <= 32) {
|
|
init_block_num = 64 / config_.block_len;
|
|
}
|
|
|
|
attn(q_in, output, attn_lse, layer_idx, generate_token_idx, q_len, batch_size, max_block_num, block_table,
|
|
cache_seqlens, topk, init_block_num, local, backend);
|
|
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
std::chrono::duration<double> diff = end - start;
|
|
// printf("layer %d time of computing attention with kvcache: %f s\n",
|
|
// layer_idx, diff.count());
|
|
}
|
|
|
|
void KVCache::quantize_q_(const uint16_t* q_in_data, int batch_size) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
|
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
|
|
// quantize q
|
|
for (int i = 0; i < config_.kv_head_num; i++) {
|
|
for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
|
|
q_fp32_[batch_idx][i][j] =
|
|
GGML_FP16_TO_FP32(q_in_data[batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim +
|
|
i * n_gqa_ * config_.head_dim + j]);
|
|
}
|
|
}
|
|
} else {
|
|
// quantize q
|
|
for (int i = 0; i < config_.kv_head_num; i++) {
|
|
for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
|
|
q_fp32[j] = GGML_FP16_TO_FP32(q_in_data[batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim +
|
|
i * n_gqa_ * config_.head_dim + j]);
|
|
}
|
|
quantize_row_q8_0(q_fp32.data(), q_q8_0_[batch_idx][i].data(), n_gqa_ * config_.head_dim);
|
|
}
|
|
}
|
|
}
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
// printf("time of quantizing q: %f s\n",
|
|
// std::chrono::duration<double>(end - start).count());
|
|
}
|
|
void KVCache::attn_initialize_layer_(int batch_size, int layer_idx, int* block_table, int& max_block_num,
|
|
int* cache_seqlens) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
|
// initialize output_fp32_ and attn_lse_
|
|
for (int i = 0; i < config_.kv_head_num; i++) {
|
|
for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
|
|
output_fp32_[batch_idx][i][j] = 0;
|
|
}
|
|
for (int j = 0; j < n_gqa_; j++) {
|
|
attn_lse_[batch_idx][i][j] = 0;
|
|
}
|
|
}
|
|
// clear top_similar_block_
|
|
|
|
while (!top_similar_block_[batch_idx].empty()) top_similar_block_[batch_idx].pop();
|
|
}
|
|
|
|
// get block_table_before_retrieval_ and cache_seqlens_
|
|
if (block_table == nullptr) {
|
|
max_block_num = past_block_num_[layer_idx];
|
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
|
if (cache_total_len_ != 0)
|
|
cache_seqlens_[batch_idx] = cache_total_len_;
|
|
else
|
|
cache_seqlens_[batch_idx] = max_block_num * config_.block_len;
|
|
for (int i = 0; i < max_block_num; i++) {
|
|
block_table_before_retrieval_[batch_idx][i] = i;
|
|
block_similar_[batch_idx][i] = 0;
|
|
}
|
|
}
|
|
} else {
|
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
|
cache_seqlens_[batch_idx] = cache_seqlens[batch_idx];
|
|
for (int i = 0; i < max_block_num; i++) {
|
|
block_table_before_retrieval_[batch_idx][i] = block_table[batch_idx * max_block_num + i];
|
|
block_similar_[batch_idx][i] = 0;
|
|
}
|
|
}
|
|
}
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
// printf("layer %d time of initializing attention: %f s\n", layer_idx,
|
|
// std::chrono::duration<double>(end - start).count());
|
|
}
|
|
|
|
void KVCache::calculate_block_similarity_layer_(const uint16_t* q_in_data, int batch_size, int layer_idx, int q_len,
|
|
int max_block_num, int* cache_seqlens, int init_block_num,
|
|
int local_block_num, int pick_block_num, WorkerPool* backend) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
|
|
if (batch_size == 1 && config_.anchor_num == 1) { // TODO: improve batch_size > 1
|
|
for (int batch_id = 0; batch_id < batch_size; batch_id++) {
|
|
if (q_len == 1) {
|
|
for (int j = 0; j < config_.head_dim * config_.q_head_num; j++) {
|
|
avg_q[batch_id][j] =
|
|
GGML_FP16_TO_FP32(q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim + j]);
|
|
avg_q_fp16[batch_id][j] = q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim + j];
|
|
}
|
|
} else {
|
|
for (int j = 0; j < config_.head_dim * config_.q_head_num; j++) {
|
|
avg_q[batch_id][j] = 0;
|
|
}
|
|
for (int i = 0; i < q_len; i++) {
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
avg_q[batch_id][j] += GGML_FP16_TO_FP32(q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim +
|
|
i * config_.q_head_num * config_.head_dim + j]);
|
|
}
|
|
}
|
|
for (int j = 0; j < config_.head_dim * config_.q_head_num; j++) {
|
|
avg_q[batch_id][j] /= q_len;
|
|
avg_q_fp16[batch_id][j] = GGML_FP32_TO_FP16(avg_q[batch_id][j]);
|
|
}
|
|
}
|
|
int seq_len = cache_seqlens_[batch_id];
|
|
int block_num = (seq_len / config_.block_len) - local_block_num - init_block_num;
|
|
if (block_num <= 0) {
|
|
continue;
|
|
}
|
|
bool is_seq = true;
|
|
for (int i = init_block_num + 1; i < (seq_len / config_.block_len) - local_block_num; i++) {
|
|
if (block_table_before_retrieval_[batch_id][i] != block_table_before_retrieval_[batch_id][i - 1] + 1) {
|
|
is_seq = false;
|
|
break;
|
|
}
|
|
}
|
|
if (is_seq) {
|
|
int nth = backend->get_thread_num();
|
|
backend->do_work_stealing_job(
|
|
nth, nullptr,
|
|
[&](int task_id) {
|
|
int ith = task_id;
|
|
bool ok = llamafile_sgemm(
|
|
block_num, 1, config_.q_head_num * config_.head_dim,
|
|
anchor_.data() +
|
|
(layer_idx * config_.max_block_num + block_table_before_retrieval_[batch_id][init_block_num]) *
|
|
config_.anchor_num * config_.q_head_num * config_.head_dim,
|
|
config_.q_head_num * config_.head_dim, avg_q_fp16[batch_id].data(),
|
|
config_.q_head_num * config_.head_dim, block_similar_[batch_id].data() + init_block_num, block_num,
|
|
ith, nth, GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F16, GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT);
|
|
if (!ok) {
|
|
printf("llamafile_sgemm failed\n");
|
|
}
|
|
},
|
|
nullptr);
|
|
} else {
|
|
backend->do_work_stealing_job(
|
|
block_num, nullptr,
|
|
[&](int task_id) {
|
|
int block_id = task_id + init_block_num;
|
|
int block_idx = block_table_before_retrieval_[batch_id][block_id];
|
|
bool ok = llamafile_sgemm(
|
|
1, 1, config_.q_head_num * config_.head_dim,
|
|
anchor_.data() +
|
|
(layer_idx * config_.max_block_num + block_table_before_retrieval_[batch_id][block_idx]) *
|
|
config_.anchor_num * config_.q_head_num * config_.head_dim,
|
|
config_.q_head_num * config_.head_dim, avg_q_fp16[batch_id].data(),
|
|
config_.q_head_num * config_.head_dim, block_similar_[batch_id].data() + block_id, 1, 0, 1,
|
|
GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F16, GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT);
|
|
if (!ok) {
|
|
printf("llamafile_sgemm failed\n");
|
|
}
|
|
},
|
|
nullptr);
|
|
}
|
|
}
|
|
} else {
|
|
backend->do_work_stealing_job(
|
|
batch_size * max_block_num, nullptr,
|
|
[&](int task_id) {
|
|
int batch_id = task_id / max_block_num;
|
|
int block_id = task_id % max_block_num;
|
|
int seq_len = cache_seqlens_[batch_id];
|
|
|
|
if (block_id < init_block_num || block_id >= (seq_len / config_.block_len) - local_block_num) {
|
|
return;
|
|
}
|
|
|
|
int block_idx = block_table_before_retrieval_[batch_id][block_id];
|
|
float sim = 0;
|
|
|
|
for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
|
|
for (int i = 0; i < config_.head_dim; i++) {
|
|
float q_i = 0, qa_i = std::numeric_limits<float>::lowest();
|
|
for (int q_id = 0; q_id < q_len; q_id++) {
|
|
q_i += GGML_FP16_TO_FP32(
|
|
q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim +
|
|
q_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + i]);
|
|
}
|
|
q_i /= q_len;
|
|
for (int anchor_id = 0; anchor_id < config_.anchor_num; anchor_id++) {
|
|
qa_i = std::max(
|
|
qa_i,
|
|
GGML_FP16_TO_FP32(
|
|
anchor_[(long long)layer_idx * config_.max_block_num * config_.anchor_num * config_.q_head_num *
|
|
config_.head_dim +
|
|
block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +
|
|
anchor_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + i]) *
|
|
q_i);
|
|
}
|
|
sim += qa_i;
|
|
}
|
|
}
|
|
block_similar_[batch_id][block_id] = sim;
|
|
},
|
|
nullptr);
|
|
}
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
std::chrono::duration<double> diff = end - start;
|
|
// printf("layer %d time of calculating similarity: %f s\n", layer_idx,
|
|
// diff.count());
|
|
}
|
|
|
|
void KVCache::select_block_layer_(int batch_size, int layer_idx, int max_block_num, int init_block_num,
|
|
int local_block_num, int pick_block_num) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
|
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
|
if (cache_seqlens_[batch_idx] / config_.block_len <= init_block_num + pick_block_num + local_block_num) {
|
|
block_table_after_retrieval_[batch_idx].swap(block_table_before_retrieval_[batch_idx]);
|
|
selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = 0;
|
|
continue;
|
|
}
|
|
|
|
for (int block_id = init_block_num; block_id < (cache_seqlens_[batch_idx] / config_.block_len) - local_block_num;
|
|
block_id++) {
|
|
top_similar_block_[batch_idx].push(
|
|
std::make_pair(block_similar_[batch_idx][block_id], block_table_before_retrieval_[batch_idx][block_id]));
|
|
if (top_similar_block_[batch_idx].size() > pick_block_num) {
|
|
top_similar_block_[batch_idx].pop();
|
|
}
|
|
}
|
|
|
|
int i = 0;
|
|
for (; i < init_block_num; i++) {
|
|
block_table_after_retrieval_[batch_idx][i] = block_table_before_retrieval_[batch_idx][i];
|
|
}
|
|
while (!top_similar_block_[batch_idx].empty()) {
|
|
block_table_after_retrieval_[batch_idx][i] = top_similar_block_[batch_idx].top().second;
|
|
top_similar_block_[batch_idx].pop();
|
|
i++;
|
|
}
|
|
for (; i < init_block_num + pick_block_num + local_block_num; i++) {
|
|
block_table_after_retrieval_[batch_idx][i] =
|
|
block_table_before_retrieval_[batch_idx][(cache_seqlens_[batch_idx] / config_.block_len) - local_block_num +
|
|
i - init_block_num - pick_block_num];
|
|
}
|
|
if (cache_seqlens_[batch_idx] % config_.block_len != 0) {
|
|
block_table_after_retrieval_[batch_idx][i] =
|
|
block_table_before_retrieval_[batch_idx][(cache_seqlens_[batch_idx] / config_.block_len)];
|
|
cache_seqlens_[batch_idx] = (cache_seqlens_[batch_idx] % config_.block_len) + i * config_.block_len;
|
|
i++;
|
|
} else {
|
|
cache_seqlens_[batch_idx] = (cache_seqlens_[batch_idx] % config_.block_len) + i * config_.block_len;
|
|
}
|
|
for (int j = 0; j < i; j++) {
|
|
selected_blocks_history_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx][j] =
|
|
block_table_after_retrieval_[batch_idx][j];
|
|
}
|
|
selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = i;
|
|
}
|
|
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
std::chrono::duration<double> diff = end - start;
|
|
// printf("layer %d time of selecting blocks: %f s\n", layer_idx,
|
|
// diff.count());
|
|
}
|
|
|
|
// retrieval kvcache, get the init_block_num block at beginning, top
|
|
// pick_block_num similar and last local_block_num blocks. Each task
|
|
// calculates the simlarity of a certain block with the query, then push
|
|
// the block into the priority queue. Finally, the required blocks are
|
|
// pushed into the block_table_after_retrieval_.
|
|
void KVCache::retrieval_kvcache_layer_(const uint16_t* q_in_data, int init_block_num, int local_block_num,
|
|
int pick_block_num, int q_len, int generate_token_idx, int batch_size,
|
|
int layer_idx, int* cache_seqlens, int& max_block_num, WorkerPool* backend) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
max_block_num_after_retrieval_ = 0;
|
|
if (pick_block_num != -1 &&
|
|
(generate_token_idx % config_.token_step != 0 || (layer_idx % config_.layer_step != config_.layer_offset))) {
|
|
if (selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] == 0) {
|
|
max_block_num_after_retrieval_ = max_block_num;
|
|
block_table_after_retrieval_.swap(block_table_before_retrieval_);
|
|
} else {
|
|
max_block_num_after_retrieval_ =
|
|
selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step];
|
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
|
for (int i = 0; i < max_block_num_after_retrieval_; i++) {
|
|
block_table_after_retrieval_[batch_idx][i] =
|
|
selected_blocks_history_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx][i];
|
|
}
|
|
|
|
if (cache_seqlens[batch_idx] % config_.block_len == 1) {
|
|
selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] += 1;
|
|
int x = selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step];
|
|
int last_block_idx = block_table_before_retrieval_[batch_idx][cache_seqlens[batch_idx] / config_.block_len];
|
|
selected_blocks_history_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx][x - 1] =
|
|
last_block_idx;
|
|
block_table_after_retrieval_[batch_idx][x - 1] = last_block_idx;
|
|
}
|
|
cache_seqlens_[batch_idx] =
|
|
(cache_seqlens_[batch_idx] % config_.block_len) +
|
|
selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] * config_.block_len -
|
|
config_.block_len;
|
|
}
|
|
}
|
|
} else if (pick_block_num != -1) {
|
|
max_block_num_after_retrieval_ = std::min(max_block_num, init_block_num + pick_block_num + local_block_num + 1);
|
|
calculate_block_similarity_layer_(q_in_data, batch_size, layer_idx, q_len, max_block_num, cache_seqlens,
|
|
init_block_num, local_block_num, pick_block_num, backend);
|
|
select_block_layer_(batch_size, layer_idx, max_block_num, init_block_num, local_block_num, pick_block_num);
|
|
} else {
|
|
selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = 0;
|
|
max_block_num_after_retrieval_ = max_block_num;
|
|
block_table_after_retrieval_.swap(block_table_before_retrieval_);
|
|
}
|
|
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
// printf("layer %d time of retrieval kvcache: %f s\n", layer_idx,
|
|
// std::chrono::duration<double>(end - start).count());
|
|
}
|
|
void KVCache::calculate_sparsity_layer_(const uint16_t* q_in_data, float* attn_sparsity, int batch_size,
|
|
int max_block_num, int* block_table, int* cache_seqlens, WorkerPool* backend
|
|
|
|
) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
seq_len_ = config_.block_len;
|
|
backend->do_work_stealing_job(
|
|
batch_size * config_.kv_head_num * max_block_num,
|
|
[&](int thread_id) {
|
|
thread_cur_head_idx_[thread_id].first = -1;
|
|
thread_cur_head_idx_[thread_id].second = -1;
|
|
},
|
|
[&](int task_id) {
|
|
int batch_id = task_id / (config_.kv_head_num * max_block_num);
|
|
int head_id = (task_id % (config_.kv_head_num * max_block_num)) / max_block_num;
|
|
int block_id = task_id % max_block_num;
|
|
int thread_id = WorkerPool::thread_local_id;
|
|
// If the block is out of the sequence length, skip it.
|
|
if (cache_seqlens[batch_id] / config_.block_len < block_id) {
|
|
return;
|
|
}
|
|
int block_idx = block_table[batch_id * max_block_num + block_id];
|
|
if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
|
|
int seq_len = cache_seqlens_[batch_id] % config_.block_len;
|
|
if (seq_len == 0) return;
|
|
|
|
// Prepare the attention mask for the last block.
|
|
int full_blocks = seq_len / 8;
|
|
int remaining_bits = seq_len % 8;
|
|
// Fill full blocks with 1s
|
|
for (int i = 0; i < full_blocks; ++i) {
|
|
thread_local_attn_mask_[thread_id][i] = 0xFF;
|
|
}
|
|
// Fill the remaining bits in the next block
|
|
if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
|
|
thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1;
|
|
} else {
|
|
thread_local_attn_mask_[thread_id][full_blocks] = 0;
|
|
}
|
|
|
|
for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
|
|
thread_local_attn_mask_[thread_id][i] = 0;
|
|
}
|
|
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
|
|
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
|
|
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
|
|
head_id * n_gqa_ * config_.head_dim],
|
|
seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16,
|
|
0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(),
|
|
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
|
|
attn_with_kvcache_one_block_(
|
|
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
|
|
GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
|
|
sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
|
|
attn_with_kvcache_one_block_(
|
|
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
|
|
GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
|
|
sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
}
|
|
} else {
|
|
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
|
|
attn_with_kvcache_one_block_(
|
|
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
|
|
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
|
|
head_id * n_gqa_ * config_.head_dim],
|
|
seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr,
|
|
nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
|
|
sin_.data());
|
|
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
|
|
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0,
|
|
0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
|
|
thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(),
|
|
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
|
|
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0,
|
|
0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
|
|
thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(),
|
|
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
}
|
|
}
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] = thread_local_attn_lse_[thread_id][i];
|
|
}
|
|
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
|
|
int cur_head_id = thread_cur_head_idx_[thread_id].second;
|
|
if (batch_id == cur_batch_idx && head_id == cur_head_id) {
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] +
|
|
std::log(1.0 + std::exp(thread_local_attn_lse_[thread_id][i] -
|
|
thread_local_cur_attn_lse_[thread_id][i]));
|
|
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
|
|
ggml_vec_scale_f32(config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse));
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] +=
|
|
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
|
|
}
|
|
} else {
|
|
if (cur_batch_idx != -1) {
|
|
mutex_[cur_batch_idx][cur_head_id]->lock();
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
continue;
|
|
}
|
|
float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
|
|
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
|
|
attn_lse_[cur_batch_idx][cur_head_id][i]));
|
|
ggml_vec_scale_f32(config_.head_dim,
|
|
output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
|
|
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
|
|
ggml_vec_scale_f32(config_.head_dim,
|
|
thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
|
|
}
|
|
mutex_[cur_batch_idx][cur_head_id]->unlock();
|
|
}
|
|
thread_cur_head_idx_[thread_id].first = batch_id;
|
|
thread_cur_head_idx_[thread_id].second = head_id;
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i];
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] =
|
|
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
}
|
|
}
|
|
},
|
|
// Merge the results of the remaining blocks.
|
|
[&](int thread_id) {
|
|
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
|
|
int cur_head_id = thread_cur_head_idx_[thread_id].second;
|
|
if (cur_head_id != -1) {
|
|
mutex_[cur_batch_idx][cur_head_id]->lock();
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
float new_attn_lse;
|
|
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
continue;
|
|
}
|
|
new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
|
|
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
|
|
attn_lse_[cur_batch_idx][cur_head_id][i]));
|
|
ggml_vec_scale_f32(config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
|
|
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
|
|
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
|
|
}
|
|
mutex_[cur_batch_idx][cur_head_id]->unlock();
|
|
}
|
|
});
|
|
|
|
for (int i = 0; i < batch_size; i++) {
|
|
for (int j = 0; j < max_block_num_after_retrieval_; j++) {
|
|
int block_idx = block_table_after_retrieval_[i][j];
|
|
for (int k = 0; k < config_.q_head_num; k++) {
|
|
attn_sparsity[i * config_.q_head_num + k] +=
|
|
std::exp(block_lse_[i][block_idx][k] - attn_lse_[i][k / n_gqa_][k % n_gqa_]);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
std::chrono::duration<double> diff = end - start;
|
|
// printf("layer %d time of calculating sparsity: %f s\n", layer_id_,
|
|
// diff.count());
|
|
}
|
|
|
|
void KVCache::attn_initialize_kvhead_(int batch_size, int layer_idx, int* block_table, int& max_block_num,
|
|
int* cache_seqlens) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
|
// initialize output_fp32_ and attn_lse_
|
|
for (int i = 0; i < config_.kv_head_num; i++) {
|
|
for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {
|
|
output_fp32_[batch_idx][i][j] = 0;
|
|
}
|
|
for (int j = 0; j < n_gqa_; j++) {
|
|
attn_lse_[batch_idx][i][j] = 0;
|
|
}
|
|
}
|
|
|
|
// clear top_similar_block_
|
|
while (!top_similar_block_[batch_idx].empty()) top_similar_block_[batch_idx].pop();
|
|
}
|
|
|
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
|
cache_seqlens_[batch_idx] = cache_seqlens[batch_idx];
|
|
for (int i = 0; i < max_block_num; i++) {
|
|
for (int j = 0; j < config_.kv_head_num; j++) {
|
|
block_table_before_retrieval_kvhead_[batch_idx][i][j] = block_table[batch_idx * max_block_num + i];
|
|
block_similar_kv_head_[batch_idx][i][j] = 0;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
// printf("layer %d time of initializing attn: %f s\n", layer_idx,
|
|
// std::chrono::duration<double>(end - start).count());
|
|
}
|
|
void KVCache::retrieval_kvcache_kvhead_(const uint16_t* q_in_data, int init_block_num, int local_block_num,
|
|
int pick_block_num, int q_len, int generate_token_idx, int batch_size,
|
|
int layer_idx, int* cache_seqlens, int& max_block_num, WorkerPool* backend) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
max_block_num_after_retrieval_ = 0;
|
|
if (pick_block_num != -1 &&
|
|
(generate_token_idx % config_.token_step != 0 || (layer_idx % config_.layer_step != config_.layer_offset))) {
|
|
if (selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] == 0) {
|
|
max_block_num_after_retrieval_ = max_block_num;
|
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
|
for (int i = 0; i < max_block_num; i++) {
|
|
for (int j = 0; j < config_.kv_head_num; j++) {
|
|
block_table_after_retrieval_kvhead_[batch_idx][i][j] =
|
|
block_table_before_retrieval_kvhead_[batch_idx][i][j];
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
max_block_num_after_retrieval_ =
|
|
selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step];
|
|
|
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
|
for (int i = 0; i < max_block_num_after_retrieval_; i++) {
|
|
for (int j = 0; j < config_.kv_head_num; j++) {
|
|
block_table_after_retrieval_kvhead_[batch_idx][i][j] =
|
|
selected_blocks_history_kvhead_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx][i]
|
|
[j];
|
|
}
|
|
}
|
|
|
|
if (cache_seqlens[batch_idx] % config_.block_len == 1) {
|
|
selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] += 1;
|
|
int x = selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step];
|
|
for (int i = 0; i < config_.kv_head_num; i++) {
|
|
int last_block_idx =
|
|
block_table_before_retrieval_kvhead_[batch_idx][cache_seqlens[batch_idx] / config_.block_len][i];
|
|
selected_blocks_history_kvhead_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx][x - 1]
|
|
[i] = last_block_idx;
|
|
block_table_after_retrieval_kvhead_[batch_idx][x - 1][i] = last_block_idx;
|
|
}
|
|
}
|
|
cache_seqlens_[batch_idx] = std::min(
|
|
cache_seqlens_[batch_idx], (cache_seqlens_[batch_idx] % config_.block_len) +
|
|
(init_block_num + pick_block_num + local_block_num) * config_.block_len);
|
|
}
|
|
}
|
|
} else if (pick_block_num != -1) {
|
|
max_block_num_after_retrieval_ = std::min(max_block_num, init_block_num + pick_block_num + local_block_num + 1);
|
|
calculate_block_similarity_kvhead_(q_in_data, batch_size, layer_idx, q_len, max_block_num, cache_seqlens,
|
|
init_block_num, local_block_num, pick_block_num, backend);
|
|
select_block_kvhead_(batch_size, layer_idx, max_block_num, init_block_num, local_block_num, pick_block_num);
|
|
} else {
|
|
selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = 0;
|
|
max_block_num_after_retrieval_ = max_block_num;
|
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
|
for (int i = 0; i < max_block_num; i++) {
|
|
for (int j = 0; j < config_.kv_head_num; j++) {
|
|
block_table_after_retrieval_kvhead_[batch_idx][i][j] = block_table_before_retrieval_kvhead_[batch_idx][i][j];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
// printf("layer %d time of retrieval kvcache: %f s\n", layer_idx,
|
|
// std::chrono::duration<double>(end - start).count());
|
|
}
|
|
void KVCache::calculate_sparsity_kvhead_(const uint16_t* q_in_data, float* attn_sparsity, int batch_size,
|
|
int max_block_num, int* block_table, int* cache_seqlens, WorkerPool* backend) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
seq_len_ = config_.block_len;
|
|
backend->do_work_stealing_job(
|
|
batch_size * config_.kv_head_num * max_block_num,
|
|
[&](int thread_id) {
|
|
thread_cur_head_idx_[thread_id].first = -1;
|
|
thread_cur_head_idx_[thread_id].second = -1;
|
|
},
|
|
[&](int task_id) {
|
|
int batch_id = task_id / (config_.kv_head_num * max_block_num);
|
|
int head_id = (task_id % (config_.kv_head_num * max_block_num)) / max_block_num;
|
|
int block_id = task_id % max_block_num;
|
|
int thread_id = WorkerPool::thread_local_id;
|
|
// If the block is out of the sequence length, skip it.
|
|
if (cache_seqlens[batch_id] / config_.block_len < block_id) {
|
|
return;
|
|
}
|
|
int block_idx = block_table[batch_id * max_block_num + block_id];
|
|
if (cache_seqlens_[batch_id] / config_.block_len == block_id) {
|
|
int seq_len = cache_seqlens_[batch_id] % config_.block_len;
|
|
if (seq_len == 0) return;
|
|
|
|
// Prepare the attention mask for the last block.
|
|
int full_blocks = seq_len / 8;
|
|
int remaining_bits = seq_len % 8;
|
|
|
|
// Fill full blocks with 1s
|
|
for (int i = 0; i < full_blocks; ++i) {
|
|
thread_local_attn_mask_[thread_id][i] = 0xFF;
|
|
}
|
|
// Fill the remaining bits in the next block
|
|
if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {
|
|
thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1;
|
|
} else {
|
|
thread_local_attn_mask_[thread_id][full_blocks] = 0;
|
|
}
|
|
|
|
for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {
|
|
thread_local_attn_mask_[thread_id][i] = 0;
|
|
}
|
|
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
|
|
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
|
|
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
|
|
head_id * n_gqa_ * config_.head_dim],
|
|
seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16,
|
|
0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(),
|
|
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
|
|
attn_with_kvcache_one_block_(
|
|
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
|
|
GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
|
|
sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
|
|
attn_with_kvcache_one_block_(
|
|
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),
|
|
GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
|
|
sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
}
|
|
} else {
|
|
if (config_.kv_type == ggml_type::GGML_TYPE_F16) {
|
|
attn_with_kvcache_one_block_(
|
|
config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,
|
|
(void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +
|
|
head_id * n_gqa_ * config_.head_dim],
|
|
seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr,
|
|
nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),
|
|
sin_.data());
|
|
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {
|
|
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0,
|
|
0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
|
|
thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(),
|
|
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
} else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {
|
|
attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,
|
|
q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0,
|
|
0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,
|
|
GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0,
|
|
nullptr, nullptr, thread_local_attn_score_[thread_id].data(),
|
|
thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_attn_lse_[thread_id].data(),
|
|
thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());
|
|
dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),
|
|
thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);
|
|
}
|
|
}
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] = thread_local_attn_lse_[thread_id][i];
|
|
}
|
|
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
|
|
int cur_head_id = thread_cur_head_idx_[thread_id].second;
|
|
if (batch_id == cur_batch_idx && head_id == cur_head_id) {
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] +
|
|
std::log(1.0 + std::exp(thread_local_attn_lse_[thread_id][i] -
|
|
thread_local_cur_attn_lse_[thread_id][i]));
|
|
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
|
|
ggml_vec_scale_f32(config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse));
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] +=
|
|
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;
|
|
}
|
|
} else {
|
|
if (cur_batch_idx != -1) {
|
|
mutex_[cur_batch_idx][cur_head_id]->lock();
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
continue;
|
|
}
|
|
float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
|
|
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
|
|
attn_lse_[cur_batch_idx][cur_head_id][i]));
|
|
ggml_vec_scale_f32(config_.head_dim,
|
|
output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
|
|
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
|
|
ggml_vec_scale_f32(config_.head_dim,
|
|
thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
|
|
}
|
|
mutex_[cur_batch_idx][cur_head_id]->unlock();
|
|
}
|
|
thread_cur_head_idx_[thread_id].first = batch_id;
|
|
thread_cur_head_idx_[thread_id].second = head_id;
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i];
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] =
|
|
thread_local_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
}
|
|
}
|
|
},
|
|
// Merge the results of the remaining blocks.
|
|
[&](int thread_id) {
|
|
int cur_batch_idx = thread_cur_head_idx_[thread_id].first;
|
|
int cur_head_id = thread_cur_head_idx_[thread_id].second;
|
|
if (cur_head_id != -1) {
|
|
mutex_[cur_batch_idx][cur_head_id]->lock();
|
|
for (int i = 0; i < n_gqa_; i++) {
|
|
float new_attn_lse;
|
|
if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
continue;
|
|
}
|
|
new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +
|
|
std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -
|
|
attn_lse_[cur_batch_idx][cur_head_id][i]));
|
|
ggml_vec_scale_f32(config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,
|
|
std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));
|
|
ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,
|
|
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));
|
|
for (int j = 0; j < config_.head_dim; j++) {
|
|
output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=
|
|
thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];
|
|
}
|
|
attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;
|
|
}
|
|
mutex_[cur_batch_idx][cur_head_id]->unlock();
|
|
}
|
|
});
|
|
|
|
for (int i = 0; i < batch_size; i++) {
|
|
for (int j = 0; j < max_block_num_after_retrieval_; j++) {
|
|
for (int k = 0; k < config_.q_head_num; k++) {
|
|
int block_idx = block_table_after_retrieval_kvhead_[i][j][k / n_gqa_];
|
|
attn_sparsity[i * config_.q_head_num + k] +=
|
|
std::exp(block_lse_[i][block_idx][k] - attn_lse_[i][k / n_gqa_][k % n_gqa_]);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
std::chrono::duration<double> diff = end - start;
|
|
// printf("layer %d time of calculating sparsity: %f s\n", layer_id_,
|
|
// diff.count());
|
|
}
|
|
void KVCache::calculate_block_similarity_kvhead_(const uint16_t* q_in_data, int batch_size, int layer_idx, int q_len,
|
|
int max_block_num, int* cache_seqlens, int init_block_num,
|
|
int local_block_num, int pick_block_num, WorkerPool* backend) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
backend->do_work_stealing_job(
|
|
batch_size * max_block_num, nullptr,
|
|
[&](int task_id) {
|
|
int batch_id = task_id / max_block_num;
|
|
int block_id = task_id % max_block_num;
|
|
int seq_len = cache_seqlens_[batch_id];
|
|
|
|
if (block_id < init_block_num || block_id >= (seq_len / config_.block_len) - local_block_num) {
|
|
return;
|
|
}
|
|
int block_idx = block_table_before_retrieval_kvhead_[batch_id][block_id][0];
|
|
|
|
for (int head_id = 0; head_id < config_.q_head_num; head_id++) {
|
|
for (int i = 0; i < config_.head_dim; i++) {
|
|
float q_i = 0, qa_i = std::numeric_limits<float>::lowest();
|
|
for (int q_id = 0; q_id < q_len; q_id++) {
|
|
q_i += GGML_FP16_TO_FP32(
|
|
q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim +
|
|
q_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + i]);
|
|
}
|
|
q_i /= q_len;
|
|
for (int anchor_id = 0; anchor_id < config_.anchor_num; anchor_id++) {
|
|
qa_i = std::max(
|
|
qa_i,
|
|
GGML_FP16_TO_FP32(
|
|
anchor_[layer_idx * config_.max_block_num * config_.anchor_num * config_.q_head_num *
|
|
config_.head_dim +
|
|
block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +
|
|
anchor_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + i]) *
|
|
q_i);
|
|
}
|
|
block_similar_kv_head_[batch_id][block_id][head_id / n_gqa_] += qa_i;
|
|
}
|
|
}
|
|
},
|
|
nullptr);
|
|
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
std::chrono::duration<double> diff = end - start;
|
|
// printf("layer %d time of calculating similarity: %f s\n", layer_idx,
|
|
// diff.count());
|
|
}
|
|
void KVCache::select_block_kvhead_(int batch_size, int layer_idx, int max_block_num, int init_block_num,
|
|
int local_block_num, int pick_block_num) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
|
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
|
int cache_len_after_retrieval = 0;
|
|
if (cache_seqlens_[batch_idx] / config_.block_len <= init_block_num + pick_block_num + local_block_num) {
|
|
selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = 0;
|
|
for (int i = 0; i < max_block_num; i++) {
|
|
for (int j = 0; j < config_.kv_head_num; j++) {
|
|
block_table_after_retrieval_kvhead_[batch_idx][i][j] = block_table_before_retrieval_kvhead_[batch_idx][i][j];
|
|
}
|
|
}
|
|
continue;
|
|
}
|
|
for (int head_id = 0; head_id < config_.kv_head_num; head_id++) {
|
|
for (int block_id = init_block_num; block_id < (cache_seqlens_[batch_idx] / config_.block_len) - local_block_num;
|
|
block_id++) {
|
|
top_similar_block_[batch_idx].push(
|
|
std::make_pair(block_similar_kv_head_[batch_idx][block_id][head_id],
|
|
block_table_before_retrieval_kvhead_[batch_idx][block_id][head_id]));
|
|
if (top_similar_block_[batch_idx].size() > pick_block_num) {
|
|
top_similar_block_[batch_idx].pop();
|
|
}
|
|
}
|
|
|
|
int i = 0;
|
|
for (; i < init_block_num; i++) {
|
|
block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =
|
|
block_table_before_retrieval_kvhead_[batch_idx][i][head_id];
|
|
}
|
|
while (!top_similar_block_[batch_idx].empty()) {
|
|
block_table_after_retrieval_kvhead_[batch_idx][i][head_id] = top_similar_block_[batch_idx].top().second;
|
|
top_similar_block_[batch_idx].pop();
|
|
i++;
|
|
}
|
|
for (; i < init_block_num + pick_block_num + local_block_num; i++) {
|
|
block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =
|
|
block_table_before_retrieval_kvhead_[batch_idx][(cache_seqlens_[batch_idx] / config_.block_len) -
|
|
local_block_num + i - init_block_num - pick_block_num]
|
|
[head_id];
|
|
}
|
|
if (cache_seqlens_[batch_idx] % config_.block_len != 0) {
|
|
block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =
|
|
block_table_before_retrieval_kvhead_[batch_idx][(cache_seqlens_[batch_idx] / config_.block_len)][head_id];
|
|
cache_len_after_retrieval = (cache_seqlens_[batch_idx] % config_.block_len) + i * config_.block_len;
|
|
i++;
|
|
} else {
|
|
cache_len_after_retrieval = (cache_seqlens_[batch_idx] % config_.block_len) + i * config_.block_len;
|
|
}
|
|
for (int j = 0; j < i; j++) {
|
|
selected_blocks_history_kvhead_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx][j]
|
|
[head_id] = block_table_after_retrieval_kvhead_[batch_idx][j][head_id];
|
|
}
|
|
}
|
|
cache_seqlens_[batch_idx] = cache_len_after_retrieval;
|
|
selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] =
|
|
(cache_len_after_retrieval + config_.block_len - 1) / config_.block_len;
|
|
}
|
|
|
|
// Timer end
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
std::chrono::duration<double> diff = end - start;
|
|
// printf("layer %d time of selecting block: %f s\n", layer_idx,
|
|
// diff.count())
|
|
}
|
|
|
|
void KVCache::get_attn_sparsity(const ggml_fp16_t* q_in, float* attn_sparsity, int layer_idx, int generate_token_idx,
|
|
int q_len, int batch_size, int max_block_num, int* block_table, int* cache_seqlens,
|
|
int* block_table_origin, int* cache_seqlens_origin, int max_block_num_origin, int topk,
|
|
int local, WorkerPool* backend) {
|
|
// Timer start
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
layer_id_ = layer_idx;
|
|
int thread_num = backend->get_thread_num();
|
|
batch_size = 1;
|
|
|
|
const uint16_t* q_in_data = const_cast<const uint16_t*>(q_in);
|
|
|
|
quantize_q_(q_in_data, batch_size);
|
|
if (config_.retrieval_type == RetrievalType::LAYER) {
|
|
attn_initialize_layer_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens);
|
|
retrieval_kvcache_layer_(q_in_data, 1, local, topk, q_len, generate_token_idx, batch_size, layer_idx, cache_seqlens,
|
|
max_block_num, backend);
|
|
calculate_sparsity_layer_(q_in_data, attn_sparsity, batch_size, max_block_num_origin, block_table_origin,
|
|
cache_seqlens_origin, backend);
|
|
} else if (config_.retrieval_type == RetrievalType::KVHEAD) {
|
|
attn_initialize_kvhead_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens);
|
|
retrieval_kvcache_kvhead_(q_in_data, 1, local, topk, q_len, generate_token_idx, batch_size, layer_idx,
|
|
cache_seqlens, max_block_num, backend);
|
|
calculate_sparsity_kvhead_(q_in_data, attn_sparsity, batch_size, max_block_num_origin, block_table_origin,
|
|
cache_seqlens_origin, backend);
|
|
}
|
|
}
|
|
|
|
void KVCache::attn_with_kvcache_one_block_(int head_dim, int bsz,
|
|
ggml_type q_type, // GGML data type of `Q`, only supports fp16 and q8_0
|
|
// [bsz, head_dim]
|
|
// Quantization is always on the head_dim dimension (per_token). If
|
|
// head_dim % 32 != 0, an error will be raised. The size must be bsz *
|
|
// head_dim/32 * qtype_size.
|
|
const void* q,
|
|
|
|
int past_kv_len, int past_kv_offset,
|
|
bool is_full_attn, // true indicates a full 1 mask
|
|
// If is_full_attn = false, a bit matrix representing the mask is
|
|
// passed. [bsz, past_kv_len]
|
|
const uint8_t* attn_mask,
|
|
|
|
ggml_type k_type, // GGML data type of `K Cache`, only supports fp16,
|
|
// q4_0, q8_0
|
|
int k_quant_type, // 0 for per_token, 1 for per_channel, others raise an
|
|
// error
|
|
// [seq_len, head_dim]
|
|
// If quant_type == 0, head_dim % 32 must be 0.
|
|
// If quant_type == 1, seq_len % 32 must be 0.
|
|
const void* k_cache,
|
|
|
|
// k_anchor_type must be fp16
|
|
int num_k_anchor, // num_k_anchor == 0 indicates no anchor
|
|
// [num_k_anchor, head_dim]
|
|
const void* k_cache_anchors,
|
|
// Each token is associated with the nearest previous position's anchor,
|
|
// with the same distance.
|
|
const int* k_cache_anchor_pos,
|
|
|
|
// v_cache similar to k_cache
|
|
ggml_type v_type, int v_quant_type,
|
|
// [head_dim, seq_len]
|
|
const void* v_cache, int num_v_anchor, const void* v_cache_anchors,
|
|
const int* v_cache_anchor_pos,
|
|
|
|
// Pre-allocated buffer for intermediate calculations [bsz,
|
|
// past_kv_len]. No malloc is performed inside this function.
|
|
float* attn_score,
|
|
|
|
// Output: [bsz, head_dim], with the same type as q_type
|
|
void* output,
|
|
// [bsz]
|
|
float* lse,
|
|
|
|
// Pre-allocated temporary buffer with sufficient size:
|
|
// (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * past_kv_len *
|
|
// head_dim + past_kv_len * head_dim / 32) bytes.
|
|
void* draft,
|
|
|
|
// Apply rotary embedding online
|
|
const int* rotary_angle, const void* rotary_cos, const void* rotary_sin
|
|
// rotary_cos=None,
|
|
// rotary_sin=None,
|
|
// cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
|
|
// cache_batch_idx: Optional[torch.Tensor] = None,
|
|
// rotary_interleaved=True,
|
|
|
|
// // Not supported for now
|
|
// window_size=(-1, -1), # -1 means infinite context window
|
|
// alibi_slopes=None,
|
|
) {
|
|
assert(head_dim % 32 == 0);
|
|
assert(k_quant_type == 0);
|
|
assert(v_quant_type == 1);
|
|
assert(q_type == GGML_TYPE_F16 || q_type == GGML_TYPE_Q8_0);
|
|
if (q_type == GGML_TYPE_F16) {
|
|
assert(k_type == GGML_TYPE_F16);
|
|
assert(v_type == GGML_TYPE_F16);
|
|
|
|
// attn = q * k + q * k_anchor
|
|
// TODO: anchor
|
|
assert(num_k_anchor == 0);
|
|
|
|
if (rotary_angle != nullptr) {
|
|
ggml_fp16_t* k_cache_with_rope_fp16 =
|
|
(reinterpret_cast<ggml_fp16_t*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
|
|
sizeof(float) * bsz * head_dim);
|
|
// dequant k_cache and apply rope
|
|
// k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i)
|
|
// k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i)
|
|
|
|
// k(i)cos(i) -> k_rope(i)
|
|
// k(i)sin(i+l) -> k_rope(i+l)
|
|
|
|
// k(i)cos(i) -> k_rope(i)
|
|
// -k(i)sin(i-l) -> k_rope(i-l)
|
|
|
|
std::vector<float> block_fp32(32);
|
|
for (int k = 0; k < past_kv_len; k++) {
|
|
int angle = rotary_angle[k];
|
|
for (int l = 0; l < head_dim / 32; l++) {
|
|
for (int m = 0; m < 32; m++) {
|
|
float x = GGML_FP16_TO_FP32(((ggml_fp16_t*)k_cache)[k * head_dim + l * 32 + m]);
|
|
float sin_val = GGML_FP16_TO_FP32(((ggml_fp16_t*)rotary_sin)[angle * head_dim + l * 32 + m]);
|
|
float cos_val = GGML_FP16_TO_FP32(((ggml_fp16_t*)rotary_cos)[angle * head_dim + l * 32 + m]);
|
|
|
|
if (l * 32 + m < head_dim / 2) {
|
|
k_cache_with_rope_fp16[k * head_dim + l * 32 + m] = GGML_FP32_TO_FP16(x * cos_val);
|
|
k_cache_with_rope_fp16[k * head_dim + l * 32 + m + head_dim / 2] = GGML_FP32_TO_FP16(-x * sin_val);
|
|
} else {
|
|
k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =
|
|
GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(k_cache_with_rope_fp16[k * head_dim + l * 32 + m]) + x * sin_val);
|
|
k_cache_with_rope_fp16[k * head_dim + l * 32 + m - head_dim / 2] = GGML_FP32_TO_FP16(
|
|
GGML_FP16_TO_FP32(k_cache_with_rope_fp16[k * head_dim + l * 32 + m - head_dim / 2]) - x * cos_val);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
llamafile_sgemm(past_kv_len, bsz, head_dim, (ggml_fp16_t*)k_cache_with_rope_fp16, head_dim, (ggml_fp16_t*)q,
|
|
head_dim, attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16,
|
|
GGML_TYPE_F32, GGML_PREC_DEFAULT);
|
|
} else {
|
|
bool ok = llamafile_sgemm(past_kv_len, bsz, head_dim, (ggml_fp16_t*)k_cache, head_dim, (ggml_fp16_t*)q, head_dim,
|
|
attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16,
|
|
GGML_TYPE_F32, GGML_PREC_DEFAULT);
|
|
|
|
if (!ok) {
|
|
printf("llamafile_sgemm failed\n");
|
|
}
|
|
}
|
|
// attn = attn * scale
|
|
float scale_factor = 1.0 / std::sqrt(float(head_dim));
|
|
ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor);
|
|
|
|
// attn = attn & mask
|
|
if (!is_full_attn) {
|
|
for (int i = 0; i < bsz; i++) {
|
|
for (int j = 0; j < past_kv_len; j++) {
|
|
int index = i * past_kv_len + j;
|
|
if (!(attn_mask[j / 8] & (1 << (j % 8)))) {
|
|
attn_score[index] = std::numeric_limits<float>::lowest();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// attn = softmax(attn)
|
|
for (int i = 0; i < bsz; i++) {
|
|
float sum_exp = 0;
|
|
for (int j = 0; j < past_kv_len; j++) {
|
|
attn_score[i * past_kv_len + j] = std::exp(attn_score[i * past_kv_len + j]);
|
|
sum_exp += attn_score[i * past_kv_len + j];
|
|
}
|
|
for (int j = 0; j < past_kv_len; j++) {
|
|
attn_score[i * past_kv_len + j] /= sum_exp;
|
|
}
|
|
if (lse != nullptr) {
|
|
lse[i] = std::log(sum_exp);
|
|
}
|
|
}
|
|
|
|
// output = attn * v + attn * v_anchor
|
|
// std::vector<float> sum(bsz * head_dim);
|
|
float* sum =
|
|
reinterpret_cast<float*>(reinterpret_cast<char*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0);
|
|
|
|
// float* attn_score_fp16(bsz, past_kv_len)
|
|
ggml_fp16_t* attn_score_fp16 = (reinterpret_cast<ggml_fp16_t*>(reinterpret_cast<char*>(draft) +
|
|
sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
|
|
sizeof(float) * bsz * head_dim));
|
|
|
|
for (int i = 0; i < bsz * past_kv_len; i++) {
|
|
attn_score_fp16[i] = GGML_FP32_TO_FP16(attn_score[i]);
|
|
}
|
|
|
|
// TODO: anchor
|
|
assert(num_v_anchor == 0);
|
|
bool ok = llamafile_sgemm(head_dim, bsz, past_kv_len, (ggml_fp16_t*)v_cache, past_kv_len,
|
|
(ggml_fp16_t*)attn_score_fp16, past_kv_len, sum, head_dim, 0, 1, GGML_TASK_TYPE_COMPUTE,
|
|
v_type, GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT);
|
|
if (!ok) {
|
|
printf("llamafile_sgemm failed\n");
|
|
}
|
|
|
|
// copy to output
|
|
for (int i = 0; i < bsz; i++) {
|
|
for (int j = 0; j < head_dim; j++) {
|
|
((float*)output)[i * head_dim + j] = sum[i * head_dim + j];
|
|
}
|
|
}
|
|
} else {
|
|
assert(k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q8_0);
|
|
assert(v_type == GGML_TYPE_Q4_0 || v_type == GGML_TYPE_Q8_0);
|
|
|
|
// attn = q * k + q * k_anchor
|
|
// TODO: anchor
|
|
assert(num_k_anchor == 0);
|
|
|
|
if (rotary_angle != nullptr) {
|
|
ggml_fp16_t* k_cache_with_rope_fp16 =
|
|
(reinterpret_cast<ggml_fp16_t*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
|
|
sizeof(float) * bsz * head_dim);
|
|
block_q4_0* k_cache_with_rope_q4 =
|
|
(reinterpret_cast<block_q4_0*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +
|
|
sizeof(float) * bsz * head_dim) +
|
|
sizeof(ggml_fp16_t) * bsz * head_dim;
|
|
// dequant k_cache and apply rope
|
|
// k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i)
|
|
// k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i)
|
|
|
|
// k(i)cos(i) -> k_rope(i)
|
|
// k(i)sin(i+l) -> k_rope(i+l)
|
|
|
|
// k(i)cos(i) -> k_rope(i)
|
|
// -k(i)sin(i-l) -> k_rope(i-l)
|
|
|
|
std::vector<float> block_fp32(32);
|
|
for (int k = 0; k < past_kv_len; k++) {
|
|
int angle = rotary_angle[k];
|
|
for (int l = 0; l < head_dim / 32; l++) {
|
|
block_q4_0 block = ((block_q4_0*)k_cache)[k * head_dim / 32 + l];
|
|
dequantize_row_q4_0(&block, block_fp32.data(), 32);
|
|
for (int m = 0; m < 32; m++) {
|
|
float sin_val = GGML_FP16_TO_FP32(((ggml_fp16_t*)rotary_sin)[angle * head_dim + l * 32 + m]);
|
|
float cos_val = GGML_FP16_TO_FP32(((ggml_fp16_t*)rotary_cos)[angle * head_dim + l * 32 + m]);
|
|
|
|
if (l * 32 + m < head_dim / 2) {
|
|
k_cache_with_rope_fp16[k * head_dim + l * 32 + m] = GGML_FP32_TO_FP16(block_fp32[m] * cos_val);
|
|
k_cache_with_rope_fp16[k * head_dim + l * 32 + m + head_dim / 2] =
|
|
GGML_FP32_TO_FP16(-block_fp32[m] * sin_val);
|
|
} else {
|
|
k_cache_with_rope_fp16[k * head_dim + l * 32 + m] += GGML_FP32_TO_FP16(block_fp32[m] * sin_val);
|
|
k_cache_with_rope_fp16[k * head_dim + l * 32 + m - head_dim / 2] -=
|
|
GGML_FP32_TO_FP16(block_fp32[m] * cos_val);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// quantize k_cache_with_rope_fp16
|
|
for (int k = 0; k < past_kv_len; k++) {
|
|
for (int l = 0; l < head_dim / 32; l++) {
|
|
for (int m = 0; m < 32; m++) {
|
|
block_fp32[m] = GGML_FP16_TO_FP32(k_cache_with_rope_fp16[k * head_dim + l * 32 + m]);
|
|
}
|
|
quantize_row_q4_0(block_fp32.data(), &k_cache_with_rope_q4[k * head_dim / 32 + l], 32);
|
|
}
|
|
}
|
|
|
|
llamafile_sgemm(past_kv_len, bsz, head_dim / 32, (block_q4_0*)k_cache_with_rope_q4, head_dim / 32, (block_q8_0*)q,
|
|
head_dim / 32, attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_Q8_0,
|
|
GGML_TYPE_F32, GGML_PREC_DEFAULT);
|
|
} else {
|
|
llamafile_sgemm(past_kv_len, bsz, head_dim / 32, (block_q4_0*)k_cache, head_dim / 32, (block_q8_0*)q,
|
|
head_dim / 32, attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_Q8_0,
|
|
GGML_TYPE_F32, GGML_PREC_DEFAULT);
|
|
}
|
|
|
|
// attn = attn * scale
|
|
float scale_factor = 1.0 / std::sqrt(float(head_dim));
|
|
ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor);
|
|
|
|
// attn = attn & mask
|
|
if (!is_full_attn) {
|
|
for (int i = 0; i < bsz; i++) {
|
|
for (int j = 0; j < past_kv_len; j++) {
|
|
int index = i * past_kv_len + j;
|
|
if (!(attn_mask[j / 8] & (1 << (j % 8)))) {
|
|
attn_score[index] = std::numeric_limits<float>::lowest();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// attn = softmax(attn)
|
|
for (int i = 0; i < bsz; i++) {
|
|
float sum_exp = 0;
|
|
for (int j = 0; j < past_kv_len; j++) {
|
|
attn_score[i * past_kv_len + j] = std::exp(attn_score[i * past_kv_len + j]);
|
|
sum_exp += attn_score[i * past_kv_len + j];
|
|
}
|
|
for (int j = 0; j < past_kv_len; j++) {
|
|
attn_score[i * past_kv_len + j] /= sum_exp;
|
|
}
|
|
if (lse != nullptr) {
|
|
lse[i] = std::log(sum_exp);
|
|
}
|
|
}
|
|
|
|
// output = attn * v + attn * v_anchor
|
|
// std::vector<block_q8_0> attn_q8_0(bsz * past_kv_len / QK8_0);
|
|
block_q8_0* attn_q8_0 = reinterpret_cast<block_q8_0*>(draft);
|
|
quantize_row_q8_0(attn_score, attn_q8_0, bsz * past_kv_len);
|
|
// std::vector<float> sum(bsz * head_dim);
|
|
float* sum =
|
|
reinterpret_cast<float*>(reinterpret_cast<char*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0);
|
|
// TODO: anchor
|
|
assert(num_v_anchor == 0);
|
|
llamafile_sgemm(head_dim, bsz, past_kv_len / 32, (block_q4_0*)v_cache, past_kv_len / 32, attn_q8_0,
|
|
past_kv_len / 32, sum, head_dim, 0, 1, GGML_TASK_TYPE_COMPUTE, v_type, GGML_TYPE_Q8_0,
|
|
GGML_TYPE_F32, GGML_PREC_DEFAULT);
|
|
|
|
quantize_row_q8_0(sum, (block_q8_0*)output, bsz * head_dim);
|
|
}
|
|
}
|