ktransformers/archive/csrc/balance_serve/kvc2/src/kvc2.h
Jiaqi Liao 57d14d22bc Refactor: restructure repository to focus on kt-kernel and KT-SFT modules (#1581)
2025-11-10 17:42:26 +08:00


#pragma once
#include <torch/torch.h>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "defs.h"
#include "model_config.h"
namespace kvc2 {

struct GPUPageCacheConfig {
  bool gpu_only;
  std::vector<size_t> gpu_devices_id;
  size_t layer_count;
  size_t total_kvcache_pages;
  size_t num_token_per_page;
  size_t num_k_heads;
  size_t k_head_dim;
  bool full_kv_cache_on_each_gpu = false;
  bool k_cache_on = true;
  bool v_cache_on = true;
  torch::ScalarType tensor_type;
  // for the CUDA stream manager
  size_t num_streams_per_device = 4;
};
struct KVC2Config {
  bool k_cache_on = true;
  bool v_cache_on = true;
  bool gpu_only = false;
  bool load_from_disk = true;
  bool save_to_disk = true;
  std::string path;
  std::string config_path;
  TokenLength num_token_per_page = 256;
  size_t memory_pool_size = 10'000'000'000;  // 10 GB
  size_t evict_count = 20;
  std::optional<GPUPageCacheConfig> gpu_cache_config = std::nullopt;
  size_t metrics_port = 0;  // default-initialized to avoid an indeterminate value
  double recompute_ratio = 0.2;
};
class DoubleCacheHandleInterface;

class KVC2Interface {
 public:
  virtual ~KVC2Interface() = default;
  virtual void load() = 0;
  virtual void save() = 0;
  /*
    Raw Insert
    Insert kvcache from user-provided buffers to disk.
    model_name, quant_type: identify the model and quantization of the cache
    id: pointer to the start of the token array
    length: length of the token array, in tokens
    k_cache, v_cache: per-layer kvcache data
    This first matches the token array against the existing kvcache, then inserts the unmatched part to disk.
  */
  virtual void raw_insert(ModelName model_name, QuantType quant_type, Token* id, TokenLength length,
                          const std::vector<layer_data>& k_cache, const std::vector<layer_data>& v_cache) = 0;
  /*
    Raw Read
    Read kvcache from disk into user-specified buffers.
    model_name, quant_type: identify the model and quantization of the cache
    id: pointer to the start of the token array
    length: length of the token array, in tokens
    k_cache, v_cache: per-layer destination buffers
    Return: length of the matched prefix, in tokens
    This bypasses the memory pool and reads directly from disk.
  */
  virtual TokenLength raw_read(ModelName model_name, QuantType quant_type, Token* id, TokenLength length,
                               const std::vector<layer_data>& k_cache, const std::vector<layer_data>& v_cache) = 0;
  /*
    Lookup
    Look up kvcache and load it from disk into the memory pool if needed.
    model_name, quant_type: identify the model and quantization of the cache
    id: pointer to the start of the token array
    length: length of the token array, in tokens
    estimated_length: estimated final length of the token sequence
    Return: a handle that holds the kvcache until it is released.
    If nothing is matched, matched_length() returns 0.
    If the memory pool is full, returns nullptr.
  */
  virtual std::shared_ptr<DoubleCacheHandleInterface> lookup(ModelName model_name, QuantType quant_type, Token* id,
                                                             TokenLength length, TokenLength estimated_length) = 0;
  /*
    Lookup and allocate to GPU
    Same as lookup(), but the matched kvcache is also placed on GPU.
  */
  virtual std::shared_ptr<DoubleCacheHandleInterface> lookup_to_gpu(ModelName model_name, QuantType quant_type,
                                                                    Token* id, TokenLength length,
                                                                    TokenLength estimated_length) = 0;
  virtual void lookup_to_gpu_async(ModelName model_name, QuantType quant_type, Token* id, TokenLength length,
                                   TokenLength estimated_length,
                                   std::function<void(std::shared_ptr<DoubleCacheHandleInterface>)> call_back) = 0;

  virtual std::pair<std::vector<torch::Tensor>, std::vector<torch::Tensor>> get_kvcache() = 0;

  virtual void debug() = 0;
};
std::shared_ptr<KVC2Interface> create_kvc2(KVC2Config config);

enum MatchStatus {
  Exact,
  Partial,
  NotMatchExact,
  NotMatchPartial,
};
class DoubleCacheHandleInterface {
 public:
  virtual ~DoubleCacheHandleInterface() = default;
  virtual TokenLength matched_length() = 0;
  virtual std::vector<MatchStatus> matched_status() = 0;
  virtual std::vector<layer_data> handle_data(bool is_key_cache) = 0;
  virtual bool to_gpu() = 0;
  virtual void to_gpu_async(std::function<void(bool)> call_back) = 0;
  virtual std::vector<size_t> get_gpu_block_idx() = 0;
  virtual std::vector<size_t> get_gpu_attached_block_idx() = 0;
  virtual void append_tokens(Token* tokens, TokenLength length) = 0;  // update with newly generated tokens
  virtual void debug() = 0;
};

}  // namespace kvc2