Files
ktransformers/kt-kernel/operators/kvcache/kvcache_load_dump.cpp
2025-10-12 05:13:00 +00:00

102 lines
4.6 KiB
C++

/**
* @Description : Binary load/dump of the KV cache (total length, anchor, K/V blocks, importance) from/to a file.
* @Author : Jianwei Dong
* @Date : 2024-08-26 22:47:06
* @Version : 1.0.0
* @LastEditors : Jianwei Dong
* @LastEditTime : 2024-08-26 22:47:06
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include <chrono>
#include "kvcache.h"
/**
 * Restores the KV cache from a binary file laid out the way dump_kvcache writes it.
 *
 * Sections are read in this order, each as the raw in-memory bytes of its buffer:
 *   1. the total cache length (same size as cache_total_len_)
 *   2. anchor_
 *   3. for every layer i:
 *        - for every kv head j and every block k: the K block, then the V block
 *          (fp16 buffers for GGML_TYPE_F16, q4_0 buffers for GGML_TYPE_Q4_0;
 *          nothing is read for any other kv_type)
 *        - for every block k and every row l < config_.block_len: importance_[i][k][l]
 *
 * Blocks are stored into slots 0 .. past_block_num-1 of each per-layer buffer, and
 * past_block_num_ is set to that count for every layer.
 *
 * @param tensor_file_path Path of the file to read.
 * @param backend          Not used by this function.
 * @throws std::runtime_error if the file cannot be opened, if the stored length is
 *         negative, or if any section cannot be read in full (truncated/unreadable file).
 *         When a section read fails part-way, buffers filled before the failure keep
 *         the data already loaded.
 */
void KVCache::load_kvcache(std::string tensor_file_path, WorkerPool* backend) {
    // Timer start
    const auto start = std::chrono::high_resolution_clock::now();

    std::ifstream ifs_tensor(tensor_file_path, std::ios::binary);
    if (!ifs_tensor) {
        throw std::runtime_error("Failed to open tensor file");
    }

    // Every section goes through this helper so that a short or failed read aborts
    // the load instead of silently leaving the destination buffer partly filled.
    auto read_exact = [&ifs_tensor](void* dst, std::size_t byte_count, const char* error_message) {
        ifs_tensor.read(static_cast<char*>(dst), static_cast<std::streamsize>(byte_count));
        if (!ifs_tensor) {
            throw std::runtime_error(error_message);
        }
    };

    // Read the header into a local first: the member is only updated once the value
    // has been read completely and validated.
    auto stored_total_len = cache_total_len_;
    read_exact(&stored_total_len, sizeof(stored_total_len), "Failed to read cache length from tensor file");
    if (stored_total_len < 0) {
        throw std::runtime_error("Invalid cache length in tensor file");
    }
    cache_total_len_ = stored_total_len;

    // Number of blocks needed to hold cache_total_len_ entries (ceiling division).
    // NOTE(review): past_block_num is not checked against the number of blocks the
    // cache buffers were allocated with; that capacity is not visible here — confirm
    // callers only load files that fit.
    const int past_block_num = (cache_total_len_ + config_.block_len - 1) / config_.block_len;
    printf("cache_total_len: %d, past_block_num: %d\n", cache_total_len_, past_block_num);
    for (int i = 0; i < config_.layer_num; ++i) {
        past_block_num_[i] = past_block_num;
    }

    read_exact(anchor_.data(), anchor_.size() * sizeof(ggml_fp16_t), "Tensor file is truncated or unreadable (anchor)");

    for (int i = 0; i < config_.layer_num; ++i) {
        for (int j = 0; j < config_.kv_head_num; ++j) {
            for (int k = 0; k < past_block_num_[i]; ++k) {
                if (config_.kv_type == GGML_TYPE_F16) {
                    read_exact(k_cache_fp16_[i][j][k].data(), k_cache_fp16_[i][j][k].size() * sizeof(ggml_fp16_t),
                               "Tensor file is truncated or unreadable (K cache)");
                    read_exact(v_cache_fp16_[i][j][k].data(), v_cache_fp16_[i][j][k].size() * sizeof(ggml_fp16_t),
                               "Tensor file is truncated or unreadable (V cache)");
                } else if (config_.kv_type == GGML_TYPE_Q4_0) {
                    read_exact(k_cache_q4[i][j][k].data(), k_cache_q4[i][j][k].size() * sizeof(block_q4_0),
                               "Tensor file is truncated or unreadable (K cache)");
                    read_exact(v_cache_q4[i][j][k].data(), v_cache_q4[i][j][k].size() * sizeof(block_q4_0),
                               "Tensor file is truncated or unreadable (V cache)");
                }
            }
        }
        for (int k = 0; k < past_block_num_[i]; ++k) {
            for (int l = 0; l < config_.block_len; l++) {
                read_exact(importance_[i][k][l].data(), importance_[i][k][l].size() * sizeof(ggml_fp16_t),
                           "Tensor file is truncated or unreadable (importance)");
            }
        }
    }
    ifs_tensor.close();

    // Timer end
    const auto end = std::chrono::high_resolution_clock::now();
    const std::chrono::duration<double> diff = end - start;
    printf("time of load: %f s\n", diff.count());
}
/**
 * Writes the KV cache to a binary file in the layout load_kvcache reads back.
 *
 * Sections are written in this order, each as the raw in-memory bytes of its buffer:
 *   1. cache_total_len (an int)
 *   2. anchor_
 *   3. for every layer i:
 *        - for every kv head j and every logical block k: the K block, then the V block
 *          (fp16 buffers for GGML_TYPE_F16, q4_0 buffers for GGML_TYPE_Q4_0;
 *          nothing is written for any other kv_type)
 *        - for every logical block k and every row l < config_.block_len: the importance row
 *
 * Logical block k is taken from slot block_table[k] of the cache buffers, so the file
 * holds the blocks in logical order regardless of where they sit in memory.
 *
 * Failures are reported on std::cerr and the function returns without throwing.
 *
 * @param block_table      Maps logical block index k to the slot used in the cache
 *                         buffers; must provide at least
 *                         ceil(cache_total_len / config_.block_len) entries.
 * @param cache_total_len  Total cache length stored in the file header; it also
 *                         determines how many blocks are written.
 * @param tensor_file_path Path of the file to create or overwrite.
 * @param backend          Not used by this function.
 */
void KVCache::dump_kvcache(int* block_table, int cache_total_len, std::string tensor_file_path, WorkerPool* backend) {
    // Timer start
    const auto start = std::chrono::high_resolution_clock::now();

    // Number of blocks needed to hold cache_total_len entries (ceiling division).
    const int past_block_num = (cache_total_len + config_.block_len - 1) / config_.block_len;

    // Reject a missing block table before the output file is opened, so an existing
    // dump is not truncated by a call that cannot succeed.
    if (past_block_num > 0 && block_table == nullptr) {
        std::cerr << "Cannot dump kvcache: block_table is null" << std::endl;
        return;
    }

    std::ofstream ofs(tensor_file_path, std::ios::binary);
    printf("dump_kvcache: %s\n", tensor_file_path.c_str());
    if (!ofs.is_open()) {
        std::cerr << "Cannot open file " << tensor_file_path << std::endl;
        return;
    }

    ofs.write(reinterpret_cast<const char*>(&cache_total_len), sizeof(cache_total_len));
    printf("cache_total_len: %d, past_block_num: %d\n", cache_total_len, past_block_num);
    ofs.write(reinterpret_cast<const char*>(anchor_.data()), anchor_.size() * sizeof(ggml_fp16_t));

    // NOTE(review): entries of block_table are used as indices without a range check;
    // the allocated block count is not visible here — confirm callers pass valid slots.
    for (int i = 0; i < config_.layer_num; ++i) {
        for (int j = 0; j < config_.kv_head_num; ++j) {
            for (int k = 0; k < past_block_num; ++k) {
                const int block_idx = block_table[k];
                if (config_.kv_type == GGML_TYPE_F16) {
                    ofs.write(reinterpret_cast<const char*>(k_cache_fp16_[i][j][block_idx].data()),
                              k_cache_fp16_[i][j][block_idx].size() * sizeof(ggml_fp16_t));
                    ofs.write(reinterpret_cast<const char*>(v_cache_fp16_[i][j][block_idx].data()),
                              v_cache_fp16_[i][j][block_idx].size() * sizeof(ggml_fp16_t));
                } else if (config_.kv_type == GGML_TYPE_Q4_0) {
                    ofs.write(reinterpret_cast<const char*>(k_cache_q4[i][j][block_idx].data()),
                              k_cache_q4[i][j][block_idx].size() * sizeof(block_q4_0));
                    ofs.write(reinterpret_cast<const char*>(v_cache_q4[i][j][block_idx].data()),
                              v_cache_q4[i][j][block_idx].size() * sizeof(block_q4_0));
                }
            }
        }
        for (int k = 0; k < past_block_num; ++k) {
            const int block_idx = block_table[k];
            for (int l = 0; l < config_.block_len; l++) {
                ofs.write(reinterpret_cast<const char*>(importance_[i][block_idx][l].data()),
                          importance_[i][block_idx][l].size() * sizeof(ggml_fp16_t));
            }
        }
        // Once a write has failed the stream stays in an error state and later writes
        // do nothing, so stop walking the remaining layers.
        if (!ofs) {
            break;
        }
    }

    // close() flushes buffered data; checking afterwards catches both earlier write
    // errors and a failed final flush.
    ofs.close();
    if (ofs.fail()) {
        std::cerr << "Failed to write file " << tensor_file_path << std::endl;
        return;
    }

    // Timer end
    const auto end = std::chrono::high_resolution_clock::now();
    const std::chrono::duration<double> diff = end - start;
    printf("time of dump: %f s\n", diff.count());
}