Files
ktransformers/kt-kernel/python/cli/utils/kv_cache_calculator.py
Oql 56cbd69ac4 kt-cli enhancement (#1834)
* [feat]: redesign kt run interactive configuration with i18n support

- Redesign kt run with 8-step interactive flow (model selection, inference method, NUMA/CPU, GPU experts, KV cache, GPU/TP selection, parsers, host/port)
- Add configuration save/load system (~/.ktransformers/run_configs.yaml)
- Add i18n support for kt chat (en/zh translations)
- Add universal input validators with auto-retry and Chinese comma support
- Add port availability checker with auto-suggestion
- Add parser configuration (--tool-call-parser, --reasoning-parser)
- Remove tuna command and clean up redundant files
- Fix: variable reference bug in run.py, filter to show only MoE models

* [feat]: unify model selection UI and disable shared experts fusion by default

- Unify kt run model selection table with kt model list display
  * Add Total size, MoE Size, Repo, and SHA256 status columns
  * Use consistent formatting and styling
  * Improve user decision-making with more information

- Disable shared experts fusion by default
  * Change the default of --disable-shared-experts-fusion from False to True
  * Users can still re-enable fusion with --enable-shared-experts-fusion

* [feat]: improve kt chat with performance metrics and better CJK support

- Add performance metrics display after each response
  * Total time, TTFT (Time To First Token), TPOT (Time Per Output Token)
  * Accurate input/output token counts using model tokenizer
  * Fallback to estimation if tokenizer unavailable
  * Metrics shown in dim style (not prominent)

- Fix Chinese character input issues
  * Replace Prompt.ask() with console.input() for better CJK support
  * Fixes backspace deletion showing half-characters

- Suppress NumPy subnormal warnings
  * Filter "The value of the smallest subnormal" warnings
  * Cleaner CLI output on certain hardware environments

* [fix]: correct TTFT measurement in kt chat

- Move start_time initialization before API call
- Previously start_time was set when receiving first chunk, causing TTFT ≈ 0ms
- Now correctly measures time from request sent to first token received

* [docs]: 添加 Clawdbot 集成指南 - KTransformers 企业级 AI 助手部署方案

* [docs]: 强调推荐使用 Kimi K2.5 作为核心模型,突出企业级推理能力

* [docs]: 添加 Clawdbot 飞书接入教程链接

* [feat]: improve CLI table display, model verification, and chat experience

- Add sequence number (#) column to all model tables by default
- Filter kt edit to show only MoE GPU models (exclude AMX)
- Extend kt model verify to check *.json and *.py files in addition to weights
- Fix re-verification bug where repaired files caused false failures
- Suppress tokenizer debug output in kt chat token counting

* [fix]: fix cpu cores.

---------

Co-authored-by: skqliao <skqliao@gmail.com>
2026-02-04 16:44:54 +08:00

208 lines
6.6 KiB
Python

#!/usr/bin/env python3
"""
KV Cache Size Calculator for SGLang
This script calculates the KV cache size in GB for a given model and number of tokens.
It follows the same logic as in sglang/srt/model_executor/model_runner.py
"""
import os
import sys
import torch
from transformers import AutoConfig
# Add sglang to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "python"))
from sglang.srt.configs.model_config import ModelConfig, is_deepseek_nsa, get_nsa_index_head_dim
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
# Bytes per element for the dtype names this calculator understands: torch-style
# names ("float16", "float8_e4m3fn"), their short aliases ("half", "float") and
# the fp8 spellings used for KV-cache dtypes ("fp8_e4m3", "fp8_e5m2").
_DTYPE_BYTES = {
    "float32": 4,
    "float": 4,
    "float16": 2,
    "half": 2,
    "bfloat16": 2,
    "float8_e4m3fn": 1,
    "float8_e5m2": 1,
    "fp8_e4m3": 1,
    "fp8_e5m2": 1,
    "auto": 2,  # Assumed 16-bit; the real dtype depends on the model config.
}


def get_dtype_bytes(dtype_str: str) -> int:
    """Get the number of bytes per element for a given dtype string.

    Args:
        dtype_str: dtype name, e.g. "bfloat16", "float32", "half", "fp8_e4m3".

    Returns:
        Bytes per element. Unknown names fall back to 2 (a 16-bit dtype).
    """
    return _DTYPE_BYTES.get(dtype_str, 2)
def _resolve_kv_dtype_bytes(model_config, dtype: str) -> int:
    """Return the bytes per element used to size the main KV cache.

    Prefers the torch dtype that ``ModelConfig`` resolved ``dtype`` to, so that
    "auto" and aliases such as "float"/"half" are sized from the actual dtype
    rather than from a fixed guess.

    NOTE(review): assumes ``model_config.dtype`` is the resolved torch dtype and
    that the KV cache uses it (sglang's behaviour when the KV cache dtype is
    "auto"). Confirm against the sglang version in use.
    """
    resolved = getattr(model_config, "dtype", None)
    if isinstance(resolved, torch.dtype):
        return torch._utils._element_size(resolved)
    # No resolved torch dtype available: fall back to the name-based table.
    return get_dtype_bytes(dtype)


def _mla_cell_info(model_config, hf_config, num_layers: int, dtype_bytes: int):
    """Compute the per-token KV footprint of an MLA model.

    Returns ``(info, cell_size)``:
    - ``info`` holds the fields merged into the result dict.
    - ``cell_size`` is the total number of bytes per token, including the NSA
      indexer cache when the model uses NSA.
    """
    kv_lora_rank = model_config.kv_lora_rank
    qk_rope_head_dim = model_config.qk_rope_head_dim
    # One (kv_lora_rank + qk_rope_head_dim)-wide entry per layer. Unlike the MHA
    # path, this formula does not depend on tp.
    main_cell_size = (kv_lora_rank + qk_rope_head_dim) * num_layers * dtype_bytes
    info = {
        "kv_lora_rank": kv_lora_rank,
        "qk_rope_head_dim": qk_rope_head_dim,
        "cell_size_bytes": main_cell_size,
    }
    if not is_deepseek_nsa(hf_config):
        info["is_nsa"] = False
        return info, main_cell_size

    # NSA (Native Sparse Attention) models additionally cache indexer keys per token.
    index_head_dim = get_nsa_index_head_dim(hf_config)
    # Index keys plus 4 extra elements per quantization block (presumably the
    # per-block scale stored next to them), using NSATokenToKVPool's layout constants.
    indexer_size_per_token = index_head_dim + index_head_dim // NSATokenToKVPool.quant_block_size * 4
    indexer_dtype_bytes = torch._utils._element_size(NSATokenToKVPool.index_k_with_scale_buffer_dtype)
    indexer_cell_size = indexer_size_per_token * num_layers * indexer_dtype_bytes
    cell_size = main_cell_size + indexer_cell_size
    info.update(
        {
            "is_nsa": True,
            "index_head_dim": index_head_dim,
            "indexer_cell_size_bytes": indexer_cell_size,
            "total_cell_size_bytes": cell_size,
        }
    )
    return info, cell_size


def _mha_cell_info(model_config, tp: int, num_layers: int, dtype_bytes: int):
    """Compute the per-token KV footprint of a standard (non-MLA) model.

    Returns ``(info, cell_size)`` in the same form as ``_mla_cell_info``.
    """
    # NOTE(review): get_num_kv_heads(tp) presumably returns the KV heads held by
    # a single tensor-parallel rank, which would make MHA sizes per-rank. Confirm.
    num_kv_heads = model_config.get_num_kv_heads(tp)
    head_dim = model_config.head_dim
    v_head_dim = model_config.v_head_dim
    # K (head_dim) and V (v_head_dim) entries for every KV head in every layer.
    cell_size = num_kv_heads * (head_dim + v_head_dim) * num_layers * dtype_bytes
    info = {
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "v_head_dim": v_head_dim,
        "cell_size_bytes": cell_size,
    }
    return info, cell_size


def _print_kv_summary(result: dict) -> None:
    """Print a human-readable summary of a ``get_kv_size_gb`` result dict."""
    print(f"Model: {result['model_path']}")
    print(f"Tokens: {result['max_total_tokens']}, TP: {result['tp']}, Dtype: {result['dtype']}")
    print(f"Architecture: {'MLA' if result['is_mla'] else 'MHA'}")
    print(f"Layers: {result['num_layers']}")
    if result["is_mla"]:
        print(f"KV LoRA Rank: {result['kv_lora_rank']}, QK RoPE Head Dim: {result['qk_rope_head_dim']}")
        if result.get("is_nsa"):
            print(f"NSA Index Head Dim: {result['index_head_dim']}")
            print(
                f"Cell size: {result['total_cell_size_bytes']} bytes "
                f"(Main: {result['cell_size_bytes']}, Indexer: {result['indexer_cell_size_bytes']})"
            )
        else:
            print(f"Cell size: {result['cell_size_bytes']} bytes")
    else:
        print(f"KV Heads: {result['num_kv_heads']}, Head Dim: {result['head_dim']}, V Head Dim: {result['v_head_dim']}")
        print(f"Cell size: {result['cell_size_bytes']} bytes")
        print(f"K size: {result['k_size_gb']:.2f} GB, V size: {result['v_size_gb']:.2f} GB")
    print(f"Total KV Cache Size: {result['total_size_gb']:.2f} GB")


def get_kv_size_gb(
    model_path: str,
    max_total_tokens: int,
    tp: int = 1,
    dtype: str = "auto",
    verbose: bool = True,
) -> dict:
    """
    Calculate the KV cache size in GB for a given model and number of tokens.

    Args:
        model_path: Path to the model.
        max_total_tokens: Maximum number of tokens to cache (must be >= 0).
        tp: Tensor parallelism size (must be >= 1). Only the MHA path uses it,
            via ``model_config.get_num_kv_heads(tp)``.
        dtype: Model dtype passed to ``ModelConfig`` (e.g. auto, float16,
            bfloat16, float32). Element size comes from the torch dtype
            ModelConfig resolves it to when available, otherwise from
            ``get_dtype_bytes(dtype)``.
        verbose: Whether to print detailed information.

    Returns:
        dict: Calculation details.
        - Always present: model_path, max_total_tokens, tp, dtype, dtype_bytes,
          num_layers, is_mla, cell_size_bytes, total_size_bytes, total_size_gb.
        - MLA results add kv_lora_rank, qk_rope_head_dim and is_nsa. NSA models
          also add index_head_dim, indexer_cell_size_bytes and
          total_cell_size_bytes.
        - MHA results add num_kv_heads, head_dim, v_head_dim, k_size_gb and
          v_size_gb.

    Raises:
        ValueError: If max_total_tokens is negative or tp is smaller than 1.
    """
    # Fail fast on nonsensical inputs before the (expensive) config load.
    if max_total_tokens < 0:
        raise ValueError(f"max_total_tokens must be >= 0, got {max_total_tokens}")
    if tp < 1:
        raise ValueError(f"tp must be >= 1, got {tp}")

    model_config = ModelConfig(model_path, dtype=dtype)
    hf_config = model_config.hf_config
    dtype_bytes = _resolve_kv_dtype_bytes(model_config, dtype)
    num_layers = model_config.num_attention_layers
    # MLA (Multi-head Latent Attention) models, e.g. DeepSeek-V2/V3 and MiniCPM3.
    is_mla = hasattr(model_config, "attention_arch") and model_config.attention_arch.name == "MLA"

    result = {
        "model_path": model_path,
        "max_total_tokens": max_total_tokens,
        "tp": tp,
        "dtype": dtype,
        "dtype_bytes": dtype_bytes,
        "num_layers": num_layers,
        "is_mla": is_mla,
    }

    if is_mla:
        info, cell_size = _mla_cell_info(model_config, hf_config, num_layers, dtype_bytes)
    else:
        info, cell_size = _mha_cell_info(model_config, tp, num_layers, dtype_bytes)
    result.update(info)

    total_size_bytes = max_total_tokens * cell_size

    if not is_mla:
        # MHA keeps K and V in separate buffers; report each one.
        per_dim_bytes = max_total_tokens * result["num_kv_heads"] * num_layers * dtype_bytes
        result.update(
            {
                "k_size_gb": per_dim_bytes * result["head_dim"] / (1024**3),
                "v_size_gb": per_dim_bytes * result["v_head_dim"] / (1024**3),
            }
        )

    result.update(
        {
            "total_size_bytes": total_size_bytes,
            "total_size_gb": total_size_bytes / (1024**3),
        }
    )

    if verbose:
        _print_kv_summary(result)
    return result
def main():
    """Command-line entry point: parse arguments and report the KV cache size.

    With --quiet, only the total size in GB (two decimals) is printed.
    Otherwise get_kv_size_gb prints its detailed breakdown.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Calculate KV cache size for a model")
    cli.add_argument("model_path", help="Path to the model")
    cli.add_argument("max_total_tokens", type=int, help="Maximum number of tokens")
    cli.add_argument("--tp", type=int, default=1, help="Tensor parallelism size")
    cli.add_argument("--dtype", type=str, default="auto", help="Data type (auto, float16, bfloat16, etc.)")
    cli.add_argument("--quiet", action="store_true", help="Suppress verbose output")
    opts = cli.parse_args()

    quiet = opts.quiet
    summary = get_kv_size_gb(
        opts.model_path,
        opts.max_total_tokens,
        tp=opts.tp,
        dtype=opts.dtype,
        verbose=not quiet,
    )
    # Quiet mode: emit just the total (in GB) as a bare number.
    if quiet:
        print(f"{summary['total_size_gb']:.2f}")


if __name__ == "__main__":
    main()