kt-cli enhancement (#1834)

* [feat]: redesign kt run interactive configuration with i18n support

- Redesign kt run with 8-step interactive flow (model selection, inference method, NUMA/CPU, GPU experts, KV cache, GPU/TP selection, parsers, host/port)
- Add configuration save/load system (~/.ktransformers/run_configs.yaml)
- Add i18n support for kt chat (en/zh translations)
- Add universal input validators with auto-retry and Chinese comma support
- Add port availability checker with auto-suggestion
- Add parser configuration (--tool-call-parser, --reasoning-parser)
- Remove tuna command and clean up redundant files
- Fix: variable reference bug in run.py, filter to show only MoE models

* [feat]: unify model selection UI and enable shared experts fusion by default

- Unify kt run model selection table with kt model list display
  * Add Total size, MoE Size, Repo, and SHA256 status columns
  * Use consistent formatting and styling
  * Improve user decision-making with more information

- Enable --disable-shared-experts-fusion by default
  * Change default value from False to True
  * Users can still override with --enable-shared-experts-fusion

* [feat]: improve kt chat with performance metrics and better CJK support

- Add performance metrics display after each response
  * Total time, TTFT (Time To First Token), TPOT (Time Per Output Token)
  * Accurate input/output token counts using model tokenizer
  * Fallback to estimation if tokenizer unavailable
  * Metrics shown in dim style (not prominent)

- Fix Chinese character input issues
  * Replace Prompt.ask() with console.input() for better CJK support
  * Fixes backspace deletion showing half-characters

- Suppress NumPy subnormal warnings
  * Filter "The value of the smallest subnormal" warnings
  * Cleaner CLI output on certain hardware environments

* [fix]: correct TTFT measurement in kt chat

- Move start_time initialization before API call
- Previously start_time was set when receiving first chunk, causing TTFT ≈ 0ms
- Now correctly measures time from request sent to first token received

* [docs]: 添加 Clawdbot 集成指南 - KTransformers 企业级 AI 助手部署方案

* [docs]: 强调推荐使用 Kimi K2.5 作为核心模型,突出企业级推理能力

* [docs]: 添加 Clawdbot 飞书接入教程链接

* [feat]: improve CLI table display, model verification, and chat experience

- Add sequence number (#) column to all model tables by default
- Filter kt edit to show only MoE GPU models (exclude AMX)
- Extend kt model verify to check *.json and *.py files in addition to weights
- Fix re-verification bug where repaired files caused false failures
- Suppress tokenizer debug output in kt chat token counting

* [fix]: fix cpu cores.

---------

Co-authored-by: skqliao <skqliao@gmail.com>
This commit is contained in:
Oql
2026-02-04 16:44:54 +08:00
committed by GitHub
parent 4f64665758
commit 56cbd69ac4
23 changed files with 10327 additions and 781 deletions

View File

@@ -0,0 +1,413 @@
#!/usr/bin/env python3
"""
Quick MoE model analysis based on config.json
(reuses sglang's model registry and detection logic)
"""
import json
import hashlib
from pathlib import Path
from typing import Optional, Dict, Any
def _get_sglang_moe_architectures():
"""
从 sglang 的模型注册表获取所有 MoE 架构
复用 sglang 的代码,这样 sglang 更新后自动支持新模型
"""
try:
import sys
# 添加 sglang 路径到 sys.path
sglang_path = Path("/mnt/data2/ljq/sglang/python")
if sglang_path.exists() and str(sglang_path) not in sys.path:
sys.path.insert(0, str(sglang_path))
# 直接导入 sglang 的 ModelRegistry
# 注意:这需要 sglang 及其依赖正确安装
from sglang.srt.models.registry import ModelRegistry
# 获取所有支持的架构
supported_archs = ModelRegistry.get_supported_archs()
# 过滤出 MoE 模型(名称包含 Moe
moe_archs = {arch for arch in supported_archs if "Moe" in arch or "moe" in arch.lower()}
# 手动添加一些不带 "Moe" 字样但是 MoE 模型的架构
# DeepSeek V2/V3 系列
deepseek_moe = {arch for arch in supported_archs if arch.startswith("Deepseek") or arch.startswith("deepseek")}
moe_archs.update(deepseek_moe)
# DBRX 也是 MoE 模型
dbrx_moe = {arch for arch in supported_archs if "DBRX" in arch or "dbrx" in arch.lower()}
moe_archs.update(dbrx_moe)
# Grok 也是 MoE 模型
grok_moe = {arch for arch in supported_archs if "Grok" in arch or "grok" in arch.lower()}
moe_archs.update(grok_moe)
return moe_archs
except Exception as e:
# 如果 sglang 不可用,返回空集合
# 这种情况下,后续会使用配置文件中的其他判断方法
import warnings
warnings.warn(f"Failed to load MoE architectures from sglang: {e}. Using fallback detection methods.")
return set()
# MoE architecture names used for detection (taken from sglang when it is importable,
# otherwise an empty set).
MOE_ARCHITECTURES = _get_sglang_moe_architectures()
def _get_cache_file():
"""获取集中式缓存文件路径"""
cache_dir = Path.home() / ".ktransformers" / "cache"
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir / "moe_analysis_v2.json"
def _load_all_cache():
    """Load the complete cache mapping from the shared cache file.

    Returns:
        dict: The decoded cache contents. An empty dict is returned when the
        file is missing, cannot be read or parsed, or when its top-level JSON
        value is not an object (callers index the result like a dict).
    """
    cache_file = _get_cache_file()
    if not cache_file.exists():
        return {}
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception:
        # Best-effort cache: an unreadable file just means "nothing cached".
        return {}
    return data if isinstance(data, dict) else {}
def _save_all_cache(cache_data):
    """Write the complete cache mapping to the shared cache file as indented JSON.

    Write failures are reported with ``warnings.warn`` instead of being raised.
    """
    target = _get_cache_file()
    try:
        with open(target, "w") as handle:
            json.dump(cache_data, handle, indent=2)
    except Exception as e:
        import warnings

        warnings.warn(f"Failed to save MoE cache: {e}")
def _compute_config_fingerprint(config_path: Path) -> Optional[str]:
"""计算 config.json 指纹"""
if not config_path.exists():
return None
try:
stat = config_path.stat()
# 使用文件大小和修改时间作为指纹
fingerprint_str = f"{config_path.name}:{stat.st_size}:{int(stat.st_mtime)}"
return hashlib.md5(fingerprint_str.encode()).hexdigest()
except Exception:
return None
def _load_cache(model_path: Path) -> Optional[Dict[str, Any]]:
    """Return the cached analysis for one model, or None when it is absent or stale.

    An entry is only used when its ``cache_version`` is 2 and its stored
    fingerprint equals the current fingerprint of ``<model_path>/config.json``.
    Entries are keyed by the resolved model path.
    """
    cache_key = str(model_path.resolve())
    entries = _load_all_cache()
    if cache_key not in entries:
        return None
    try:
        entry = entries[cache_key]
        # Reject entries written by a different cache layout.
        if entry.get("cache_version", 0) != 2:
            return None
        # Reject entries whose config.json has changed since they were stored.
        expected = _compute_config_fingerprint(model_path / "config.json")
        if entry.get("fingerprint") != expected:
            return None
        return entry.get("result")
    except Exception:
        return None
def _save_cache(model_path: Path, result: Dict[str, Any]):
    """Store the analysis result for one model in the shared cache.

    The entry is keyed by the resolved model path and records the current
    config.json fingerprint, cache version 2 and an ISO timestamp. Failures
    are reported with ``warnings.warn`` instead of being raised.
    """
    cache_key = str(model_path.resolve())
    try:
        from datetime import datetime

        current_fingerprint = _compute_config_fingerprint(model_path / "config.json")
        entries = _load_all_cache()
        entries[cache_key] = {
            "fingerprint": current_fingerprint,
            "result": result,
            "cache_version": 2,
            "last_updated": datetime.now().isoformat(),
        }
        _save_all_cache(entries)
    except Exception as e:
        import warnings

        warnings.warn(f"Failed to save MoE cache for {model_path}: {e}")
def _load_config_json(model_path: Path) -> Optional[Dict[str, Any]]:
"""读取 config.json 文件
参考 sglang 的 get_config() 实现
"""
config_path = model_path / "config.json"
if not config_path.exists():
return None
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
return config
except Exception:
return None
def _is_moe_model(config: Dict[str, Any]) -> bool:
"""判断是否是 MoE 模型
参考 sglang 的模型注册表和架构识别方式
"""
# 方法1: 检查架构名称
architectures = config.get("architectures", [])
if any(arch in MOE_ARCHITECTURES for arch in architectures):
return True
# 方法2: 检查是否有 MoE 相关字段Mistral 格式)
if config.get("moe"):
return True
# 方法3: 检查是否有 num_experts 或其变体字段
# 需要检查 text_config对于某些多模态模型
text_config = config.get("text_config", config)
# 检查各种专家数量字段
if (
text_config.get("num_experts") or text_config.get("num_local_experts") or text_config.get("n_routed_experts")
): # Kimi-K2 使用这个字段
return True
return False
def _extract_moe_params(config: Dict[str, Any]) -> Dict[str, Any]:
"""从 config 中提取 MoE 参数
参考 sglang 的各种 MoE 模型实现
"""
# 处理嵌套的 text_config
text_config = config.get("text_config", config)
# 提取基本参数
result = {
"architectures": config.get("architectures", []),
"model_type": config.get("model_type", "unknown"),
}
# 专家数量(不同模型字段名不同)
num_experts = (
text_config.get("num_experts") # Qwen2/3 MoE, DeepSeek V2
or text_config.get("num_local_experts") # Mixtral
or text_config.get("n_routed_experts") # Kimi-K2, DeepSeek V3
or config.get("moe", {}).get("num_experts") # Mistral 格式
)
# 每个 token 激活的专家数
num_experts_per_tok = (
text_config.get("num_experts_per_tok")
or text_config.get("num_experts_per_token")
or config.get("moe", {}).get("num_experts_per_tok")
or 2 # 默认值
)
# 层数
num_hidden_layers = text_config.get("num_hidden_layers") or text_config.get("n_layer") or 0
# 隐藏层维度
hidden_size = text_config.get("hidden_size") or text_config.get("d_model") or 0
# MoE 专家中间层大小
moe_intermediate_size = (
text_config.get("moe_intermediate_size")
or text_config.get("intermediate_size") # 如果没有特殊的 moe_intermediate_size
or 0
)
# 共享专家中间层大小Qwen2/3 MoE
shared_expert_intermediate_size = text_config.get("shared_expert_intermediate_size", 0)
result.update(
{
"num_experts": num_experts or 0,
"num_experts_per_tok": num_experts_per_tok,
"num_hidden_layers": num_hidden_layers,
"hidden_size": hidden_size,
"moe_intermediate_size": moe_intermediate_size,
"shared_expert_intermediate_size": shared_expert_intermediate_size,
}
)
# 提取其他有用的参数
result["num_attention_heads"] = text_config.get("num_attention_heads", 0)
result["num_key_value_heads"] = text_config.get("num_key_value_heads", 0)
result["vocab_size"] = text_config.get("vocab_size", 0)
result["max_position_embeddings"] = text_config.get("max_position_embeddings", 0)
return result
def _estimate_model_size(model_path: Path) -> float:
"""估算模型总大小GB
快速统计 safetensors 文件总大小
"""
try:
total_size = 0
for file_path in model_path.glob("*.safetensors"):
total_size += file_path.stat().st_size
return total_size / (1024**3)
except Exception:
return 0.0
def analyze_moe_model(model_path, use_cache=True):
    """
    Quickly analyse an MoE model by reading only its config.json.

    Args:
        model_path: Model directory (string or Path).
        use_cache: Whether to read from / write to the analysis cache (default True).

    Returns:
        dict with the keys
            'is_moe': always True for a returned result,
            'num_experts': total number of experts,
            'num_experts_per_tok': experts activated per token,
            'num_hidden_layers': number of layers,
            'hidden_size': hidden dimension,
            'moe_intermediate_size': MoE expert intermediate size,
            'shared_expert_intermediate_size': shared-expert intermediate size,
            'architectures': list of architecture names,
            'model_type': model type string,
            'total_size_gb': estimated total size from the safetensors files,
            'cached': whether the result came from the cache,
            plus 'num_attention_heads', 'num_key_value_heads', 'vocab_size'.
        None when the path does not exist, config.json cannot be loaded, the
        model is not detected as MoE, or no expert count could be found.
    """
    model_path = Path(model_path)
    if not model_path.exists():
        return None

    # Serve from the cache when a valid entry exists.
    if use_cache:
        hit = _load_cache(model_path)
        if hit:
            hit["cached"] = True
            return hit

    config = _load_config_json(model_path)
    if not config or not _is_moe_model(config):
        return None

    params = _extract_moe_params(config)
    # An expert count is mandatory for a usable result.
    if params["num_experts"] == 0:
        return None

    result = {"is_moe": True}
    for key in (
        "num_experts",
        "num_experts_per_tok",
        "num_hidden_layers",
        "hidden_size",
        "moe_intermediate_size",
        "shared_expert_intermediate_size",
        "architectures",
        "model_type",
    ):
        result[key] = params[key]
    result["total_size_gb"] = _estimate_model_size(model_path)
    result["cached"] = False
    # Extra parameters, defaulting to 0 when absent.
    for key in ("num_attention_heads", "num_key_value_heads", "vocab_size"):
        result[key] = params.get(key, 0)

    if use_cache:
        _save_cache(model_path, result)
    return result
def print_analysis(model_path):
    """Analyse one model and print a human-readable report to stdout."""
    print(f"分析模型: {model_path}\n")
    report = analyze_moe_model(model_path)
    if report is None:
        print("不是 MoE 模型或分析失败")
        return

    rule = "=" * 70
    print(rule)
    print("MoE 模型分析结果")
    if report.get("cached"):
        print("[使用缓存]")
    print(rule)

    # Architecture section
    arch_names = ", ".join(report["architectures"])
    print("模型架构:")
    print(f" - 架构: {arch_names}")
    print(f" - 类型: {report['model_type']}")
    print()

    # MoE structure section
    print("MoE 结构:")
    print(f" - 专家总数: {report['num_experts']}")
    print(f" - 激活专家数: {report['num_experts_per_tok']} experts/token")
    print(f" - 层数: {report['num_hidden_layers']}")
    print(f" - 隐藏维度: {report['hidden_size']}")
    print(f" - MoE 中间层: {report['moe_intermediate_size']}")
    shared_size = report["shared_expert_intermediate_size"]
    if shared_size > 0:
        print(f" - 共享专家中间层: {shared_size}")
    print()

    # Size section
    print("大小统计:")
    print(f" - 模型总大小: {report['total_size_gb']:.2f} GB")
    print(rule)
    print()
def main():
    """Print the analysis for the model path given on the command line, or for two built-in sample paths."""
    import sys

    cli_args = sys.argv[1:]
    if cli_args:
        # Only the first positional argument is used.
        targets = [cli_args[0]]
    else:
        targets = ["/mnt/data2/models/Qwen3-30B-A3B", "/mnt/data2/models/Qwen3-235B-A22B-Instruct-2507"]
    for target in targets:
        print_analysis(target)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,118 @@
"""
Debug utility to inspect saved run configurations.
Usage: python -m kt_kernel.cli.utils.debug_configs
"""
from pathlib import Path
import yaml
from rich.console import Console
from rich.table import Table
from rich import box
console = Console()
def main():
    """Show all saved configurations.

    Reads ``~/.ktransformers/run_configs.yaml``, prints one table per model ID
    found under its ``configs`` key, then tries to map those IDs to model names
    through ``UserModelRegistry``. Missing, null or oddly typed fields are
    rendered with placeholders so that one malformed entry does not abort the
    whole report.
    """
    config_file = Path.home() / ".ktransformers" / "run_configs.yaml"
    console.print()
    console.print(f"[bold]Configuration file:[/bold] {config_file}")
    console.print()
    if not config_file.exists():
        console.print("[red]✗ Configuration file does not exist![/red]")
        console.print()
        console.print("No configurations have been saved yet.")
        return
    try:
        with open(config_file, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
    except Exception as e:
        console.print(f"[red]✗ Failed to load configuration file: {e}[/red]")
        return
    if not isinstance(data, dict):
        # A YAML list/scalar at the top level cannot hold a "configs" section.
        console.print("[red]✗ Failed to load configuration file: top-level value is not a mapping[/red]")
        return
    console.print("[green]✓[/green] Configuration file loaded")
    console.print()

    configs = data.get("configs") or {}
    if not isinstance(configs, dict) or not configs:
        console.print("[yellow]No saved configurations found.[/yellow]")
        return
    console.print(f"[bold]Found configurations for {len(configs)} model(s):[/bold]")
    console.print()

    for model_id, model_configs in configs.items():
        # A model key with a null value means "no configurations", not an error.
        model_configs = model_configs or []
        console.print(f"[cyan]Model ID:[/cyan] {model_id}")
        console.print(f"[dim] {len(model_configs)} configuration(s)[/dim]")
        console.print()
        if not model_configs:
            continue

        # Display configs in a table
        table = Table(box=box.ROUNDED, show_header=True, header_style="bold cyan")
        table.add_column("#", justify="right", style="cyan")
        table.add_column("Name", style="white")
        table.add_column("Method", style="yellow")
        table.add_column("TP", justify="right", style="green")
        table.add_column("GPU Experts", justify="right", style="magenta")
        table.add_column("Created", style="dim")
        for i, cfg in enumerate(model_configs, 1):
            if not isinstance(cfg, dict):
                # Render a placeholder row for an entry that is not a mapping.
                cfg = {}
            method = str(cfg.get("inference_method") or "?")
            kt_method = cfg.get("kt_method", "?")
            method_display = method.upper()
            if method == "raw":
                method_display += f" ({cfg.get('raw_method', '?')})"
            elif method == "amx":
                method_display += f" ({kt_method})"
            # YAML may decode unquoted timestamps into datetime objects, so
            # stringify before trimming to "YYYY-MM-DDTHH:MM:SS".
            created_at = cfg.get("created_at")
            created_display = str(created_at)[:19] if created_at else "Unknown"
            table.add_row(
                str(i),
                str(cfg.get("config_name", f"Config {i}")),
                method_display,
                str(cfg.get("tp_size", "?")),
                str(cfg.get("gpu_experts", "?")),
                created_display,
            )
        console.print(table)
        console.print()

    # Also check user_models.yaml to show model names
    console.print("[bold]Checking model registry...[/bold]")
    console.print()
    try:
        # Imported inside the try so an import failure is reported like any
        # other registry problem instead of ending the report with a traceback.
        from kt_kernel.cli.utils.user_model_registry import UserModelRegistry

        registry = UserModelRegistry()
        all_models = registry.list_models()
        console.print(f"[green]✓[/green] Found {len(all_models)} registered model(s)")
        console.print()

        # Map model IDs to names
        id_to_name = {m.id: m.name for m in all_models}
        console.print("[bold]Model ID → Name mapping:[/bold]")
        console.print()
        for model_id in configs:
            model_name = id_to_name.get(model_id, "[red]Unknown (model not found in registry)[/red]")
            console.print(f" {str(model_id)[:8]}... → {model_name}")
        console.print()
    except Exception as e:
        console.print(f"[yellow]⚠ Could not load model registry: {e}[/yellow]")
        console.print()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,146 @@
"""Helper functions for interactive model download."""
import fnmatch
from pathlib import Path
from typing import Any, Dict, List, Tuple
def list_remote_files_hf(repo_id: str, use_mirror: bool = False) -> List[Dict[str, Any]]:
    """
    List files in a HuggingFace repository.

    Args:
        repo_id: Repository identifier, e.g. "org/name".
        use_mirror: Query https://hf-mirror.com instead of the default
            endpoint. Ignored when the HF_ENDPOINT environment variable is
            already set, in which case that endpoint stays in effect.

    Returns:
        List of dicts with keys: 'path', 'size' (in bytes)
    """
    import os

    from huggingface_hub import HfApi

    # The mirror is passed to the client explicitly: huggingface_hub reads
    # HF_ENDPOINT when the package is imported, so exporting the variable after
    # the import (as this function used to) does not redirect the client, and
    # it also mutates process-wide state.
    if use_mirror and not os.environ.get("HF_ENDPOINT"):
        api = HfApi(endpoint="https://hf-mirror.com")
    else:
        api = HfApi()

    result = []
    for item in api.list_repo_tree(repo_id=repo_id, recursive=True):
        # Skip directories
        if getattr(item, "type", None) == "directory":
            continue
        # NOTE(review): folder entries appear to be RepoFolder objects that carry
        # a tree_id but no size — confirm against the installed huggingface_hub.
        if hasattr(item, "tree_id") and not hasattr(item, "size"):
            continue
        file_path = item.path if hasattr(item, "path") else str(item)
        # A missing or null size counts as 0 so totals can always be summed.
        file_size = getattr(item, "size", 0) or 0
        result.append({"path": file_path, "size": file_size})
    return result
def list_remote_files_ms(repo_id: str) -> List[Dict[str, Any]]:
    """
    List files in a ModelScope repository.

    Args:
        repo_id: ModelScope model identifier.

    Returns:
        List of dicts with keys: 'path', 'size' (in bytes)
    """
    from modelscope.hub.api import HubApi

    api = HubApi()
    files_info = api.get_model_files(model_id=repo_id, recursive=True)
    result = []
    for file_info in files_info:
        # NOTE(review): ModelScope listings appear to tag directories with
        # Type == "tree"; they are not downloadable files — confirm against the
        # installed modelscope version.
        if file_info.get("Type") == "tree":
            continue
        # Prefer the repo-relative "Path": in a recursive listing "Name" looks to
        # be only the base name, which would make nested files indistinguishable.
        file_path = file_info.get("Path") or file_info.get("Name") or ""
        file_size = file_info.get("Size") or 0
        result.append({"path": file_path, "size": file_size})
    return result
def filter_files_by_pattern(files: List[Dict[str, Any]], pattern: str) -> List[Dict[str, Any]]:
    """Filter files by glob pattern.

    A file is kept when the pattern matches either its base name or its full
    path. The pattern "*" returns the input list itself, unfiltered.

    Args:
        files: Dicts with at least a 'path' key.
        pattern: ``fnmatch``-style glob pattern.

    Returns:
        The matching entries, in their original order.
    """
    if pattern == "*":
        return files
    return [
        entry
        for entry in files
        if fnmatch.fnmatch(Path(entry["path"]).name, pattern) or fnmatch.fnmatch(entry["path"], pattern)
    ]
def calculate_total_size(files: List[Dict[str, Any]]) -> int:
    """Calculate total size of files in bytes.

    Entries whose 'size' is missing or None count as 0 bytes.
    """
    return sum(entry.get("size") or 0 for entry in files)
def format_file_list_table(files: List[Dict[str, Any]], max_display: int = 10):
    """Format file list as a table for display.

    Args:
        files: Dicts with 'path' and 'size' (bytes) keys.
        max_display: Maximum number of file rows; negative values are treated
            as 0. Any remaining files are summarised in one trailing row.

    Returns:
        A ``rich.table.Table`` with "File" and "Size" columns.
    """
    from rich.table import Table
    from kt_kernel.cli.utils.model_scanner import format_size

    # A negative limit would turn the slice below into "all but the last N"
    # and inflate the "more files" count.
    max_display = max(max_display, 0)

    table = Table(show_header=True, header_style="bold")
    table.add_column("File", style="cyan", overflow="fold")
    table.add_column("Size", justify="right")
    # Show first max_display files
    for file in files[:max_display]:
        table.add_row(file["path"], format_size(file.get("size") or 0))
    if len(files) > max_display:
        table.add_row(f"... and {len(files) - max_display} more files", "[dim]...[/dim]")
    return table
def verify_repo_exists(repo_id: str, repo_type: str, use_mirror: bool = False) -> Tuple[bool, str]:
    """
    Verify if a repository exists.

    Args:
        repo_id: Repository identifier.
        repo_type: "huggingface" queries HuggingFace; any other value queries
            ModelScope.
        use_mirror: For HuggingFace, query https://hf-mirror.com unless the
            HF_ENDPOINT environment variable is already set.

    Returns:
        (exists: bool, message: str). Any failure (including network or import
        errors) yields ``(False, "Repository not found: <error>")``.
    """
    try:
        if repo_type == "huggingface":
            import os

            from huggingface_hub import HfApi

            # Pass the mirror to the client directly: huggingface_hub reads
            # HF_ENDPOINT at import time, so toggling the variable around the
            # call only works if the package has not been imported yet.
            if use_mirror and not os.environ.get("HF_ENDPOINT"):
                api = HfApi(endpoint="https://hf-mirror.com")
            else:
                api = HfApi()
            api.repo_info(repo_id=repo_id, repo_type="model")
        else:  # modelscope
            from modelscope.hub.api import HubApi

            HubApi().get_model(model_id=repo_id)
        return True, "Repository found"
    except Exception as e:
        return False, f"Repository not found: {str(e)}"

View File

@@ -0,0 +1,216 @@
"""
Input validation utilities with retry mechanism.
Provides robust input validation with automatic retry on failure.
"""
from typing import Optional, List, Callable, Any
from rich.console import Console
from rich.prompt import Prompt
console = Console()
def prompt_int_with_retry(
    message: str,
    default: Optional[int] = None,
    min_val: Optional[int] = None,
    max_val: Optional[int] = None,
    validator: Optional[Callable[[int], bool]] = None,
    validator_error_msg: Optional[str] = None,
) -> int:
    """Prompt for integer input with validation and retry.

    The prompt is repeated until the input parses as an integer, lies within
    ``[min_val, max_val]`` (when given) and passes ``validator`` (when given).

    Args:
        message: Prompt message
        default: Default value (optional)
        min_val: Minimum allowed value (optional)
        max_val: Maximum allowed value (optional)
        validator: Custom validation function (optional)
        validator_error_msg: Error message for custom validator (optional)

    Returns:
        Validated integer value
    """
    # The prompt text and default never change between retries.
    prompt_text = f"{message} [{default}]" if default is not None else message
    prompt_default = str(default) if default is not None else None

    while True:
        user_input = Prompt.ask(prompt_text, default=prompt_default)

        # Prompt.ask hands back the default unchanged on empty input, so without
        # a default ``user_input`` can be None; int(None) raises TypeError, which
        # must trigger a retry just like unparsable text does.
        try:
            value = int(user_input)
        except (TypeError, ValueError):
            console.print(f"[red]✗ Invalid input. Please enter a valid integer.[/red]")
            console.print()
            continue

        # Validate range
        if min_val is not None and value < min_val:
            console.print(f"[red]✗ Value must be at least {min_val}[/red]")
            console.print()
            continue
        if max_val is not None and value > max_val:
            console.print(f"[red]✗ Value must be at most {max_val}[/red]")
            console.print()
            continue

        # Custom validation
        if validator is not None and not validator(value):
            error_msg = validator_error_msg or "Invalid value"
            console.print(f"[red]✗ {error_msg}[/red]")
            console.print()
            continue

        # All validations passed
        return value
def prompt_float_with_retry(
    message: str,
    default: Optional[float] = None,
    min_val: Optional[float] = None,
    max_val: Optional[float] = None,
) -> float:
    """Prompt for float input with validation and retry.

    The prompt is repeated until the input parses as a number (NaN is
    rejected) and lies within ``[min_val, max_val]`` (when given).

    Args:
        message: Prompt message
        default: Default value (optional)
        min_val: Minimum allowed value (optional)
        max_val: Maximum allowed value (optional)

    Returns:
        Validated float value
    """
    import math

    # The prompt text and default never change between retries.
    prompt_text = f"{message} [{default}]" if default is not None else message
    prompt_default = str(default) if default is not None else None

    while True:
        user_input = Prompt.ask(prompt_text, default=prompt_default)

        # Prompt.ask hands back the default unchanged on empty input, so without
        # a default ``user_input`` can be None; float(None) raises TypeError.
        try:
            value = float(user_input)
        except (TypeError, ValueError):
            value = math.nan
        # float("nan") parses successfully but compares False against every
        # bound, so it would otherwise slip through the range checks below.
        if math.isnan(value):
            console.print(f"[red]✗ Invalid input. Please enter a valid number.[/red]")
            console.print()
            continue

        # Validate range
        if min_val is not None and value < min_val:
            console.print(f"[red]✗ Value must be at least {min_val}[/red]")
            console.print()
            continue
        if max_val is not None and value > max_val:
            console.print(f"[red]✗ Value must be at most {max_val}[/red]")
            console.print()
            continue

        # All validations passed
        return value
def prompt_choice_with_retry(
    message: str,
    choices: List[str],
    default: Optional[str] = None,
) -> str:
    """Ask until the answer is one of ``choices`` and return it.

    Args:
        message: Prompt message
        choices: Accepted answers (compared by exact match)
        default: Value passed to the prompt as its default (optional)

    Returns:
        The accepted answer
    """
    while True:
        answer = Prompt.ask(message, default=default)
        if answer in choices:
            return answer
        # Not an accepted value: list the options and ask again.
        console.print(f"[red]✗ Invalid choice. Please select from: {', '.join(choices)}[/red]")
        console.print()
def prompt_int_list_with_retry(
    message: str,
    default: Optional[str] = None,
    min_val: Optional[int] = None,
    max_val: Optional[int] = None,
    validator: Optional[Callable[[List[int]], tuple[bool, Optional[str]]]] = None,
) -> List[int]:
    """Prompt for comma-separated integer list with validation and retry.

    Both the ASCII comma and the full-width (Chinese) comma are accepted as
    separators; spaces are ignored. Empty input yields an empty list.

    Args:
        message: Prompt message
        default: Default value as string (e.g., "0,1,2,3")
        min_val: Minimum allowed value for each integer (optional)
        max_val: Maximum allowed value for each integer (optional)
        validator: Custom validation function that returns (is_valid, error_message) (optional)

    Returns:
        List of validated integers
    """
    while True:
        # Prompt.ask hands back the default unchanged on empty input, so
        # without a default the answer can be None.
        user_input = Prompt.ask(message, default=default) or ""

        # Normalise the full-width comma (U+FF0C) to an ASCII comma and drop
        # spaces. The escape is spelled out on purpose: replacing the empty
        # string instead would insert a comma between every character and
        # split "10" into 1 and 0.
        user_input_cleaned = user_input.replace("\uff0c", ",").replace(" ", "")

        # Try to parse as integers
        try:
            values = [int(token.strip()) for token in user_input_cleaned.split(",") if token.strip()]
        except ValueError:
            console.print(f"[red]✗ Invalid format. Please enter numbers separated by commas.[/red]")
            console.print()
            continue

        # Validate each value's range
        invalid_values = [
            value
            for value in values
            if (min_val is not None and value < min_val) or (max_val is not None and value > max_val)
        ]
        if invalid_values:
            if min_val is not None and max_val is not None:
                console.print(f"[red]✗ Invalid value(s): {invalid_values}[/red]")
                console.print(f"[yellow]Valid range: {min_val}-{max_val}[/yellow]")
            elif min_val is not None:
                console.print(f"[red]✗ Value(s) must be at least {min_val}: {invalid_values}[/red]")
            elif max_val is not None:
                console.print(f"[red]✗ Value(s) must be at most {max_val}: {invalid_values}[/red]")
            console.print()
            continue

        # Custom validation
        if validator is not None:
            is_valid, error_msg = validator(values)
            if not is_valid:
                console.print(f"[red]✗ {error_msg}[/red]")
                console.print()
                continue

        # All validations passed
        return values

View File

@@ -0,0 +1,207 @@
#!/usr/bin/env python3
"""
KV Cache Size Calculator for SGLang
This script calculates the KV cache size in GB for a given model and number of tokens.
It follows the same logic as in sglang/srt/model_executor/model_runner.py
"""
import os
import sys
import torch
from transformers import AutoConfig
# Add sglang to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "python"))
from sglang.srt.configs.model_config import ModelConfig, is_deepseek_nsa, get_nsa_index_head_dim
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
def get_dtype_bytes(dtype_str: str) -> int:
    """Return the element width in bytes for a dtype name; unknown names count as 2 bytes."""
    width_groups = (
        (4, ("float32",)),
        (2, ("float16", "bfloat16", "auto")),  # "auto" usually defaults to bfloat16
        (1, ("float8_e4m3fn", "float8_e5m2")),
    )
    widths = {}
    for nbytes, names in width_groups:
        widths.update(dict.fromkeys(names, nbytes))
    try:
        return widths[dtype_str]
    except KeyError:
        return 2
def get_kv_size_gb(
    model_path: str,
    max_total_tokens: int,
    tp: int = 1,
    dtype: str = "auto",
    verbose: bool = True,
) -> dict:
    """
    Calculate the KV cache size in GB for a given model and number of tokens.

    The per-token "cell size" is computed from sglang's ``ModelConfig`` and then
    multiplied by ``max_total_tokens``. Two layouts are handled:

    * MLA models: ``(kv_lora_rank + qk_rope_head_dim) * layers * dtype_bytes``,
      plus an indexer buffer when ``is_deepseek_nsa(hf_config)`` is true.
    * Everything else: ``num_kv_heads * (head_dim + v_head_dim) * layers *
      dtype_bytes``; K and V sizes are additionally reported separately.

    Args:
        model_path: Path to the model
        max_total_tokens: Maximum number of tokens to cache
        tp: Tensor parallelism size (only passed to ``get_num_kv_heads``)
        dtype: Data type for KV cache (auto, float16, bfloat16, float8_e4m3fn, etc.)
        verbose: Whether to print detailed information

    Returns:
        dict: Dictionary containing calculation details. Always includes
        model_path, max_total_tokens, tp, dtype, dtype_bytes, num_layers,
        is_mla, cell_size_bytes, total_size_bytes and total_size_gb; the
        remaining keys depend on the branch taken (MLA / NSA / MHA).
    """
    # Load model config
    model_config = ModelConfig(model_path, dtype=dtype)
    hf_config = model_config.hf_config
    # Determine dtype bytes
    dtype_bytes = get_dtype_bytes(dtype)
    if dtype == "auto":
        # Auto dtype usually becomes bfloat16
        dtype_bytes = 2
    # Number of layers
    num_layers = model_config.num_attention_layers
    # Check if it's MLA (Multi-head Latent Attention) model
    is_mla = hasattr(model_config, "attention_arch") and model_config.attention_arch.name == "MLA"
    result = {
        "model_path": model_path,
        "max_total_tokens": max_total_tokens,
        "tp": tp,
        "dtype": dtype,
        "dtype_bytes": dtype_bytes,
        "num_layers": num_layers,
        "is_mla": is_mla,
    }
    if is_mla:
        # MLA models (DeepSeek-V2/V3, MiniCPM3, etc.)
        kv_lora_rank = model_config.kv_lora_rank
        qk_rope_head_dim = model_config.qk_rope_head_dim
        # Calculate cell size (per token)
        cell_size = (kv_lora_rank + qk_rope_head_dim) * num_layers * dtype_bytes
        # "cell_size_bytes" records the main cache only; the NSA branch below
        # grows the local cell_size but reports the sum under a separate key.
        result.update(
            {
                "kv_lora_rank": kv_lora_rank,
                "qk_rope_head_dim": qk_rope_head_dim,
                "cell_size_bytes": cell_size,
            }
        )
        # Check if it's NSA (Native Sparse Attention) model
        if is_deepseek_nsa(hf_config):
            index_head_dim = get_nsa_index_head_dim(hf_config)
            # Per-token indexer entry: index_head_dim values plus 4 extra units
            # per quantization block.
            # NOTE(review): presumably the 4 units hold a per-block scale — confirm against sglang.
            indexer_size_per_token = index_head_dim + index_head_dim // NSATokenToKVPool.quant_block_size * 4
            indexer_dtype_bytes = torch._utils._element_size(NSATokenToKVPool.index_k_with_scale_buffer_dtype)
            indexer_cell_size = indexer_size_per_token * num_layers * indexer_dtype_bytes
            cell_size += indexer_cell_size
            result.update(
                {
                    "is_nsa": True,
                    "index_head_dim": index_head_dim,
                    "indexer_cell_size_bytes": indexer_cell_size,
                    "total_cell_size_bytes": cell_size,
                }
            )
        else:
            result["is_nsa"] = False
    else:
        # Standard MHA models
        num_kv_heads = model_config.get_num_kv_heads(tp)
        head_dim = model_config.head_dim
        v_head_dim = model_config.v_head_dim
        # Calculate cell size (per token)
        cell_size = num_kv_heads * (head_dim + v_head_dim) * num_layers * dtype_bytes
        result.update(
            {
                "num_kv_heads": num_kv_heads,
                "head_dim": head_dim,
                "v_head_dim": v_head_dim,
                "cell_size_bytes": cell_size,
            }
        )
    # Calculate total KV cache size
    total_size_bytes = max_total_tokens * cell_size
    total_size_gb = total_size_bytes / (1024**3)
    # For MHA models with separate K and V buffers
    if not is_mla:
        k_size_bytes = max_total_tokens * num_kv_heads * head_dim * num_layers * dtype_bytes
        v_size_bytes = max_total_tokens * num_kv_heads * v_head_dim * num_layers * dtype_bytes
        k_size_gb = k_size_bytes / (1024**3)
        v_size_gb = v_size_bytes / (1024**3)
        result.update(
            {
                "k_size_gb": k_size_gb,
                "v_size_gb": v_size_gb,
            }
        )
    result.update(
        {
            "total_size_bytes": total_size_bytes,
            "total_size_gb": total_size_gb,
        }
    )
    if verbose:
        # The names printed below are only bound in the matching branch above,
        # so each print must stay under the same is_mla / is_nsa condition.
        print(f"Model: {model_path}")
        print(f"Tokens: {max_total_tokens}, TP: {tp}, Dtype: {dtype}")
        print(f"Architecture: {'MLA' if is_mla else 'MHA'}")
        print(f"Layers: {num_layers}")
        if is_mla:
            print(f"KV LoRA Rank: {kv_lora_rank}, QK RoPE Head Dim: {qk_rope_head_dim}")
            if result.get("is_nsa"):
                print(f"NSA Index Head Dim: {index_head_dim}")
                print(
                    f"Cell size: {cell_size} bytes (Main: {result['cell_size_bytes']}, Indexer: {result['indexer_cell_size_bytes']})"
                )
            else:
                print(f"Cell size: {cell_size} bytes")
        else:
            print(f"KV Heads: {num_kv_heads}, Head Dim: {head_dim}, V Head Dim: {v_head_dim}")
            print(f"Cell size: {cell_size} bytes")
            print(f"K size: {k_size_gb:.2f} GB, V size: {v_size_gb:.2f} GB")
        print(f"Total KV Cache Size: {total_size_gb:.2f} GB")
    return result
def main():
    """Command-line entry point: parse arguments and report the KV cache size.

    With ``--quiet`` only the total size in GB (two decimals) is printed;
    otherwise ``get_kv_size_gb`` prints its detailed breakdown.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Calculate KV cache size for a model")
    cli.add_argument("model_path", help="Path to the model")
    cli.add_argument("max_total_tokens", type=int, help="Maximum number of tokens")
    cli.add_argument("--tp", type=int, default=1, help="Tensor parallelism size")
    cli.add_argument("--dtype", type=str, default="auto", help="Data type (auto, float16, bfloat16, etc.)")
    cli.add_argument("--quiet", action="store_true", help="Suppress verbose output")
    opts = cli.parse_args()

    quiet = opts.quiet
    summary = get_kv_size_gb(
        opts.model_path,
        opts.max_total_tokens,
        tp=opts.tp,
        dtype=opts.dtype,
        verbose=not quiet,
    )
    if quiet:
        print(f"{summary['total_size_gb']:.2f}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,250 @@
"""
Model Discovery Utilities
Shared functions for discovering and registering new models across different commands.
"""
from typing import List, Optional, Tuple
from pathlib import Path
from rich.console import Console
from kt_kernel.cli.utils.model_scanner import (
discover_models,
scan_directory_for_models,
ScannedModel,
)
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry, UserModel
console = Console()
def discover_and_register_global(
    min_size_gb: float = 2.0, max_depth: int = 6, show_progress: bool = True, lang: str = "en"
) -> Tuple[int, int, List[UserModel]]:
    """
    Perform global model discovery and register new models.

    Models whose path is already known to the registry are skipped; every
    remaining model is handed to ``_create_and_register_model``.

    Args:
        min_size_gb: Minimum model size in GB
        max_depth: Maximum search depth
        show_progress: Whether to show progress messages
        lang: Language for messages ("en" or "zh")

    Returns:
        Tuple of (total_found, new_found, registered_models)
    """
    registry = UserModelRegistry()
    if show_progress:
        if lang == "zh":
            # Clause separator restored so the message mirrors the English one.
            console.print("[dim]正在扫描系统中的模型权重,这可能需要30-60秒...[/dim]")
        else:
            console.print("[dim]Scanning system for model weights, this may take 30-60 seconds...[/dim]")

    # Global scan
    all_models = discover_models(mount_points=None, min_size_gb=min_size_gb, max_depth=max_depth)

    # Filter out existing models (completed before any registration happens)
    new_models = [model for model in all_models if not registry.find_by_path(model.path)]

    # Register new models; a falsy return value means nothing was registered for that model
    registered = []
    for model in new_models:
        user_model = _create_and_register_model(registry, model)
        if user_model:
            registered.append(user_model)
    return len(all_models), len(new_models), registered
def discover_and_register_path(
    path: str,
    min_size_gb: float = 2.0,
    existing_paths: Optional[set] = None,
    show_progress: bool = True,
    lang: str = "en",
) -> Tuple[int, int, List[UserModel]]:
    """
    Scan one directory for models and register the ones that are new.

    Args:
        path: Directory to scan
        min_size_gb: Minimum model file size in GB, forwarded to scan_directory_for_models
        existing_paths: Paths already discovered in this session; these are skipped (optional)
        show_progress: Print a "scanning..." notice before the scan starts
        lang: Message language; "zh" selects Chinese, any other value English

    Returns:
        Tuple of (total_found, new_found, registered_models)
    """
    registry = UserModelRegistry()

    if show_progress:
        console.print(f"[dim]正在扫描 {path}...[/dim]" if lang == "zh" else f"[dim]Scanning {path}...[/dim]")

    found = scan_directory_for_models(path, min_file_size_gb=min_size_gb)
    if not found:
        return 0, 0, []

    # A directory is new when it is neither in the registry nor in the
    # caller-supplied set of paths already handled during this session.
    fresh = [
        ScannedModel(path=dir_path, format=fmt, size_bytes=size, file_count=count, files=names)
        for dir_path, (fmt, size, count, names) in found.items()
        if not registry.find_by_path(dir_path) and not (existing_paths and dir_path in existing_paths)
    ]

    # Failed registrations (falsy result) are dropped from the returned list.
    registered = []
    for candidate in fresh:
        added = _create_and_register_model(registry, candidate)
        if added:
            registered.append(added)

    return len(found), len(fresh), registered
def _create_and_register_model(registry: UserModelRegistry, scanned_model: ScannedModel) -> Optional[UserModel]:
    """
    Build a UserModel for a scanned directory and add it to the registry.

    The name comes from registry.suggest_name(), seeded with the folder name.
    Repository detection and MoE analysis are best-effort: any exception they
    raise is swallowed and the corresponding fields are simply left untouched.

    Args:
        registry: UserModelRegistry instance that receives the model
        scanned_model: ScannedModel describing the directory

    Returns:
        The registered UserModel, or None if registry.add_model() raised
    """
    model_dir = scanned_model.path
    entry = UserModel(
        name=registry.suggest_name(scanned_model.folder_name),
        path=model_dir,
        format=scanned_model.format,
    )

    # Repository detection (best-effort).
    try:
        from kt_kernel.cli.utils.repo_detector import detect_repo_for_model

        detected = detect_repo_for_model(model_dir)
        if detected:
            entry.repo_id, entry.repo_type = detected
    except Exception:
        pass

    # MoE analysis only runs for safetensors models (best-effort).
    if scanned_model.format == "safetensors":
        try:
            from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model

            analysis = analyze_moe_model(model_dir, use_cache=True)
            is_moe = bool(analysis and analysis.get("is_moe"))
            entry.is_moe = is_moe
            if is_moe:
                entry.moe_num_experts = analysis.get("num_experts")
                entry.moe_num_experts_per_tok = analysis.get("num_experts_per_tok")
        except Exception:
            # is_moe is left as it was when the analysis itself fails.
            pass

    try:
        registry.add_model(entry)
    except Exception:
        # Not expected given the suggested name, but never propagate.
        return None
    return entry
def format_discovery_summary(
    total_found: int,
    new_found: int,
    registered: List[UserModel],
    lang: str = "en",
    show_models: bool = True,
    max_show: int = 10,
) -> None:
    """
    Print a human-readable summary of a discovery run to the shared console.

    Args:
        total_found: Number of models the scan found
        new_found: Number of those that were not known before
        registered: UserModel objects that were actually registered
        lang: "zh" for Chinese messages, any other value for English
        show_models: Whether to list the registered models
        max_show: Maximum number of models to list
    """
    zh = lang == "zh"
    console.print()

    # Nothing new: report either "all known" or "nothing found" and stop.
    if new_found == 0:
        if total_found > 0:
            console.print(
                f"[green]✓[/green] 扫描完成：找到 {total_found} 个模型，所有模型均已在列表中"
                if zh
                else f"[green]✓[/green] Scan complete: found {total_found} models, all already in the list"
            )
        else:
            console.print("[yellow]未找到模型[/yellow]" if zh else "[yellow]No models found[/yellow]")
        return

    console.print(
        f"[green]✓[/green] 扫描完成：找到 {total_found} 个模型，其中 {new_found} 个为新模型"
        if zh
        else f"[green]✓[/green] Scan complete: found {total_found} models, {new_found} are new"
    )

    if len(registered) > 0:
        console.print(
            f"[green]✓[/green] 成功添加 {len(registered)} 个新模型到列表"
            if zh
            else f"[green]✓[/green] Successfully added {len(registered)} new models to list"
        )

    if not (show_models and registered):
        return

    console.print()
    console.print(
        f"[dim]新发现的模型（前{max_show}个）：[/dim]"
        if zh
        else f"[dim]Newly discovered models (first {max_show}):[/dim]"
    )

    # Only name, format and path are shown; sizes are not available here.
    for position, entry in enumerate(registered[:max_show], 1):
        console.print(f" {position}. {entry.name} ({entry.format})")
        console.print(f" [dim]{entry.path}[/dim]")

    if len(registered) > max_show:
        remaining = len(registered) - max_show
        console.print(
            f" [dim]... 还有 {remaining} 个新模型[/dim]"
            if zh
            else f" [dim]... and {remaining} more new models[/dim]"
        )

View File

@@ -0,0 +1,790 @@
"""
Model Scanner
Scans directories for model files (safetensors, gguf) and identifies models
"""
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Set, Tuple, Dict
from collections import defaultdict
import os
import subprocess
import json
@dataclass
class ScannedModel:
    """Lightweight record describing one model directory found by a scan"""

    path: str  # Absolute path to model directory
    format: str  # "safetensors" | "gguf" | "mixed"
    size_bytes: int  # Combined size of the model files, in bytes
    file_count: int  # How many model files were counted
    files: List[str]  # Names of the model files

    @property
    def size_gb(self) -> float:
        """Size expressed in GB (1 GB = 1024**3 bytes)"""
        bytes_per_gb = 1024**3
        return self.size_bytes / bytes_per_gb

    @property
    def folder_name(self) -> str:
        """Last component of the model path, used as the default model name"""
        model_dir = Path(self.path)
        return model_dir.name
class ModelScanner:
    """Scanner for discovering models in directory trees"""

    def __init__(self, min_size_gb: float = 10.0):
        """
        Initialize scanner

        Args:
            min_size_gb: Minimum folder size in GB to be considered a model
        """
        self.min_size_bytes = int(min_size_gb * 1024**3)

    def scan_directory(
        self, base_path: Path, exclude_paths: Optional[Set[str]] = None
    ) -> Tuple[List[ScannedModel], List[str]]:
        """
        Scan directory tree for models

        A directory counts as a model when the *.safetensors / *.gguf files
        directly inside it add up to at least the configured minimum size.

        Args:
            base_path: Root directory to scan
            exclude_paths: Set of absolute (resolved) paths to exclude from results;
                excluded directories are not descended into

        Returns:
            Tuple of (valid_models, warnings)
            - valid_models: List of ScannedModel instances
            - warnings: List of warning messages (one per mixed-format directory)

        Raises:
            ValueError: If base_path does not exist or is not a directory
        """
        if not base_path.exists():
            raise ValueError(f"Path does not exist: {base_path}")
        if not base_path.is_dir():
            raise ValueError(f"Path is not a directory: {base_path}")

        exclude_paths = exclude_paths or set()
        results: List[ScannedModel] = []
        warnings: List[str] = []

        # Walk the directory tree
        for root, dirs, files in os.walk(base_path):
            root_path = Path(root).resolve()

            # Skip excluded directories entirely
            if str(root_path) in exclude_paths:
                dirs[:] = []  # Don't descend into this directory
                continue

            # Check for model files
            safetensors_files = [f for f in files if f.endswith(".safetensors")]
            gguf_files = [f for f in files if f.endswith(".gguf")]
            if not safetensors_files and not gguf_files:
                continue  # No model files in this directory

            # Calculate total size
            model_files = safetensors_files + gguf_files
            total_size = self._calculate_total_size(root_path, model_files)

            # Check if size meets minimum threshold
            if total_size < self.min_size_bytes:
                continue  # Too small, but keep scanning subdirectories

            if safetensors_files and gguf_files:
                # Mixed format - issue warning instead of guessing
                warnings.append(
                    f"Mixed format detected in {root_path}: "
                    f"{len(safetensors_files)} safetensors + {len(gguf_files)} gguf files. "
                    "Please separate into different folders and re-scan."
                )
                dirs[:] = []  # Don't descend into mixed format directories
                continue

            format_type = "safetensors" if safetensors_files else "gguf"
            results.append(
                ScannedModel(
                    path=str(root_path),
                    format=format_type,
                    size_bytes=total_size,
                    file_count=len(model_files),
                    files=model_files,
                )
            )
            # Subdirectories are still visited: each one is checked on its
            # own against the configured minimum size.

        return results, warnings

    def scan_single_path(self, path: Path) -> Optional[ScannedModel]:
        """
        Scan a single path for model files (no size threshold is applied)

        Args:
            path: Path to scan

        Returns:
            ScannedModel instance or None if not a valid model

        Raises:
            ValueError: If the directory holds both safetensors and gguf files
        """
        if not path.exists() or not path.is_dir():
            return None

        # Find model files
        safetensors_files = list(path.glob("*.safetensors"))
        gguf_files = list(path.glob("*.gguf"))
        if not safetensors_files and not gguf_files:
            return None

        # Check for mixed format
        if safetensors_files and gguf_files:
            raise ValueError(
                f"Mixed format detected: {len(safetensors_files)} safetensors + "
                f"{len(gguf_files)} gguf files. Please use a single format."
            )

        model_files = [f.name for f in safetensors_files + gguf_files]
        total_size = self._calculate_total_size(path, model_files)
        format_type = "safetensors" if safetensors_files else "gguf"

        return ScannedModel(
            path=str(path.resolve()),
            format=format_type,
            size_bytes=total_size,
            file_count=len(model_files),
            files=model_files,
        )

    def _calculate_total_size(self, directory: Path, filenames: List[str]) -> int:
        """
        Calculate total size of specified files in directory

        Files that are missing or cannot be stat'ed are skipped.

        Args:
            directory: Directory containing the files
            filenames: List of filenames to sum

        Returns:
            Total size in bytes
        """
        total = 0
        for filename in filenames:
            # stat() directly instead of exists()-then-stat(): the separate
            # exists() call sat outside the try block, so an access error
            # could escape it, and the file could vanish between the calls.
            try:
                total += (directory / filename).stat().st_size
            except (OSError, ValueError):
                # Missing, inaccessible or unrepresentable path: skip it
                continue
        return total
# Convenience functions
def scan_directory(
    base_path: Path, min_size_gb: float = 10.0, exclude_paths: Optional[Set[str]] = None
) -> Tuple[List[ScannedModel], List[str]]:
    """
    Scan a directory tree with a throwaway ModelScanner

    Args:
        base_path: Root directory to scan
        min_size_gb: Minimum folder size in GB, passed to the ModelScanner constructor
        exclude_paths: Set of paths to exclude

    Returns:
        Tuple of (models, warnings) exactly as ModelScanner.scan_directory returns it
    """
    return ModelScanner(min_size_gb=min_size_gb).scan_directory(base_path, exclude_paths)
def scan_single_path(path: Path) -> Optional[ScannedModel]:
    """
    Inspect one directory with a default-configured ModelScanner

    Args:
        path: Path to scan

    Returns:
        ScannedModel or None, exactly as ModelScanner.scan_single_path returns it
    """
    return ModelScanner().scan_single_path(path)
def format_size(size_bytes: int) -> str:
    """
    Render a byte count as a short human-readable string

    The value is divided by 1024 until it drops below 1024 or the unit list
    (B, KB, MB, GB, TB) is exhausted; anything larger is reported in PB.

    Args:
        size_bytes: Size in bytes

    Returns:
        Formatted string with one decimal place (e.g., "42.3 GB")
    """
    units = ("B", "KB", "MB", "GB", "TB")
    value = size_bytes
    index = 0
    while index < len(units) and not value < 1024.0:
        value /= 1024.0
        index += 1
    unit = units[index] if index < len(units) else "PB"
    return f"{value:.1f} {unit}"
# ===== Fast Scanning with Find Command and Tree-based Root Detection =====
def find_files_fast(mount_point: str, pattern: str, max_depth: int = 6, timeout: int = 30) -> List[str]:
    """
    Use find command to quickly locate files

    The command is executed without a shell, as an argument list, so paths
    and patterns containing quotes, "$", backticks or spaces are passed to
    find verbatim instead of being re-interpreted by the shell.

    Args:
        mount_point: Starting directory
        pattern: File name pattern for find's -name (e.g., "config.json", "*.gguf")
        max_depth: Maximum directory depth (default: 6)
        timeout: Command timeout in seconds

    Returns:
        List of file paths as printed by find (empty on timeout or if find
        cannot be executed)
    """
    command = [
        "find",
        str(mount_point),
        "-maxdepth",
        str(max_depth),
        "-name",
        pattern,
        "-type",
        "f",
    ]
    try:
        result = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            # Permission-denied noise goes to stderr; discard it.
            stderr=subprocess.DEVNULL,
            text=True,
            timeout=timeout,
        )
    except (subprocess.TimeoutExpired, OSError):
        # Timed out, or the find executable is missing / not runnable.
        return []

    # The exit status is deliberately ignored: find exits non-zero when it
    # hits unreadable directories, yet the paths it did print are valid.
    return [line.strip() for line in result.stdout.split("\n") if line.strip()]
def is_valid_model_directory(directory: Path, min_size_gb: float = 10.0) -> Tuple[bool, Optional[str]]:
    """
    Check if a directory is a valid model directory

    A directory qualifies when it directly contains *.safetensors files
    (preferred) or *.gguf files whose combined size reaches min_size_gb.
    The presence of config.json does not influence the result.

    Args:
        directory: Path to check
        min_size_gb: Minimum size in GB

    Returns:
        (is_valid, model_type) where model_type is "safetensors", "gguf", or None
    """
    if not directory.exists() or not directory.is_dir():
        return False, None

    # safetensors takes precedence when both kinds are present.
    model_files = list(directory.glob("*.safetensors"))
    if model_files:
        model_type = "safetensors"
    else:
        model_files = list(directory.glob("*.gguf"))
        if not model_files:
            return False, None
        model_type = "gguf"

    # Only the model files of the detected type are counted (fast).
    total_size = 0
    for model_file in model_files:
        try:
            total_size += model_file.stat().st_size
        except OSError:
            # Unreadable file: contributes nothing to the total
            pass

    if total_size / (1024**3) < min_size_gb:
        return False, None
    return True, model_type
def scan_all_models_fast(mount_points: List[str], min_size_gb: float = 10.0, max_depth: int = 6) -> List[str]:
    """
    Fast scan for all model paths using find command

    Candidate directories are the parents of every config.json and *.gguf
    file found below each mount point. Each distinct candidate is validated
    once with is_valid_model_directory.

    Args:
        mount_points: List of mount points to scan (missing ones are skipped)
        min_size_gb: Minimum model size in GB
        max_depth: Maximum search depth (default: 6)

    Returns:
        Sorted list of resolved, valid model directory paths
    """
    model_paths = set()
    for mount in mount_points:
        if not os.path.exists(mount):
            continue

        # Collect parents first so a directory holding many matching files
        # is validated only once instead of once per file.
        candidate_dirs = set()
        for pattern in ("config.json", "*.gguf"):
            for hit in find_files_fast(mount, pattern, max_depth=max_depth):
                candidate_dirs.add(Path(hit).parent)

        for model_dir in candidate_dirs:
            is_valid, _model_type = is_valid_model_directory(model_dir, min_size_gb)
            if is_valid:
                model_paths.add(str(model_dir.resolve()))

    return sorted(model_paths)
def get_root_subdirs() -> List[str]:
    """
    List the directories directly under / that are worth scanning

    Well-known system directories are left out; everything else that is a
    directory is returned.

    Returns:
        Sorted list of directories to scan
    """
    system_dirs = frozenset(
        (
            "dev",
            "proc",
            "sys",
            "run",
            "boot",
            "tmp",
            "usr",
            "lib",
            "lib64",
            "bin",
            "sbin",
            "etc",
            "opt",
            "var",
            "snap",
        )
    )

    candidates = []
    try:
        for entry in os.scandir("/"):
            # Keep directories whose name is not on the system list.
            if entry.is_dir() and entry.name not in system_dirs:
                candidates.append(entry.path)
    except PermissionError:
        # Whatever was collected before the error is still returned.
        pass

    candidates.sort()
    return candidates
def _find_sized_files(directory: str, pattern: str, size_filter: str, timeout: int = 120) -> Dict[str, list]:
    """Run find for one name pattern and group the hits as {parent_dir: [(file_name, size_bytes), ...]}."""
    command = [
        "find",
        str(directory),
        "-name",
        pattern,
        "-type",
        "f",
        "-size",
        size_filter,
        # find itself expands \t and \n; the backslashes are passed literally.
        "-printf",
        "%p\\t%s\\n",
    ]
    grouped = defaultdict(list)
    try:
        result = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,  # drop permission-denied noise
            text=True,
            timeout=timeout,
        )
    except FileNotFoundError:
        # No find executable: behave like an empty scan.
        return grouped

    for line in result.stdout.split("\n"):
        # Split on the LAST tab so a tab inside the path cannot break parsing.
        file_path, separator, size_str = line.rpartition("\t")
        if not separator or not size_str.isdigit():
            continue
        file_path_obj = Path(file_path)
        grouped[str(file_path_obj.parent)].append((file_path_obj.name, int(size_str)))
    return grouped


def scan_directory_for_models(directory: str, min_file_size_gb: float = 2.0) -> Dict[str, tuple]:
    """
    Scan a directory for models using find command with size filter

    Only non-empty model files of at least min_file_size_gb are considered.
    The threshold is handed to find in bytes, so fractional values such as
    2.5 are honoured instead of being truncated to whole gigabytes. find is
    invoked without a shell, so special characters in the directory name are
    safe.

    A directory with large *.gguf files is reported as "gguf". A directory
    with large *.safetensors files is reported as "safetensors" only when it
    also contains config.json; if both apply, "safetensors" wins.

    Args:
        directory: Directory to scan
        min_file_size_gb: Minimum individual file size in GB (default: 2.0)

    Returns:
        Dict mapping model_path -> (model_type, size_bytes, file_count, files)

    Raises:
        subprocess.TimeoutExpired: If a find invocation exceeds 120 seconds
    """
    # "-size +Nc" matches files larger than N bytes, so N = threshold - 1
    # yields "at least threshold bytes"; the floor of 0 keeps the filter
    # valid and still excludes empty files.
    threshold_bytes = int(min_file_size_gb * 1024**3)
    size_filter = f"+{max(threshold_bytes - 1, 0)}c"

    model_info = {}

    # 1. Every directory with large gguf files is a model
    for dir_path, files in _find_sized_files(directory, "*.gguf", size_filter).items():
        total_size = sum(size for _, size in files)
        model_info[dir_path] = ("gguf", total_size, len(files), [name for name, _ in files])

    # 2. Directories with large safetensors files additionally need config.json
    for dir_path, files in _find_sized_files(directory, "*.safetensors", size_filter).items():
        if os.path.exists(os.path.join(dir_path, "config.json")):
            total_size = sum(size for _, size in files)
            model_info[dir_path] = ("safetensors", total_size, len(files), [name for name, _ in files])

    return model_info
def scan_all_models_with_info(
    mount_points: Optional[List[str]] = None, min_size_gb: float = 10.0, max_depth: int = 6
) -> Dict[str, tuple]:
    """
    Scan several directories concurrently and merge what each scan finds

    Every directory is handed to scan_directory_for_models on a worker thread
    (at most 8 at a time) with a fixed 2.0 GB per-file threshold. A directory
    whose scan raises is silently left out of the result.

    Args:
        mount_points: Directories to scan, or None to use get_root_subdirs()
        min_size_gb: Not used (kept for API compatibility)
        max_depth: Not used (kept for API compatibility)

    Returns:
        Dict mapping model_path -> (model_type, size_bytes, file_count, files)
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    targets = get_root_subdirs() if mount_points is None else mount_points
    if not targets:
        return {}

    merged = {}
    worker_count = min(len(targets), 8)
    with ThreadPoolExecutor(max_workers=worker_count) as pool:
        pending = [pool.submit(scan_directory_for_models, target, 2.0) for target in targets]
        # Results are merged in completion order.
        for finished in as_completed(pending):
            try:
                merged.update(finished.result())
            except Exception:
                # A failing directory must not abort the whole scan.
                pass
    return merged
def find_model_roots_from_paths(model_paths: List[str]) -> Tuple[List[str], Dict[str, int]]:
    """
    Infer the common "root" directories under which the given models live

    Algorithm:
        1. Build the set of all model paths plus every ancestor path
        2. Link each path to its parent and compute, per path, f(x) = number
           of model paths in its subtree and the largest f among its children
        3. A non-model path is a candidate when f(parent) == f(x),
           f(x) >= max(f(children)) and f(x) > 0
        4. Candidates are visited deepest first; one is dropped when an
           already accepted root lies below it with the same f, or when it
           has fewer than 3 path components

    Args:
        model_paths: List of model directory paths

    Returns:
        (root_paths, subtree_sizes) where:
        - root_paths: Sorted list of inferred root directories
        - subtree_sizes: Dict mapping each root to number of models below it
    """
    if not model_paths:
        return [], {}

    model_set = set(model_paths)

    # 1. All model paths together with each of their prefixes
    all_paths = {
        str(Path(*parts[:depth]))
        for parts in (Path(model_path).parts for model_path in model_paths)
        for depth in range(1, len(parts) + 1)
    }

    # 2. Parent -> children links (single-component paths have no parent here)
    children_map = defaultdict(list)
    for node in all_paths:
        node_obj = Path(node)
        if len(node_obj.parts) <= 1:
            continue
        parent = str(node_obj.parent)
        if parent in all_paths:
            children_map[parent].append(node)

    # 3. Depth-first accumulation of subtree counts
    f = {}  # path -> number of model paths in its subtree
    max_child_f = {}  # path -> largest f among direct children (0 if none)
    visited = set()

    def dfs(node: str) -> int:
        if node in visited:
            return f[node]
        visited.add(node)
        own_weight = 1 if node in model_set else 0
        child_totals = [dfs(child) for child in children_map.get(node, [])]
        f[node] = own_weight + sum(child_totals)
        max_child_f[node] = max(child_totals) if child_totals else 0
        return f[node]

    # Start from every path whose parent is absent (or is the path itself, e.g. "/")
    for node in all_paths:
        parent = str(Path(node).parent)
        if parent not in all_paths or parent == node:
            dfs(node)

    # 4. Candidate roots; ">=" keeps directories that hold a single model
    candidate_roots = []
    for node in all_paths:
        if node in model_set:
            continue  # model directories themselves are never roots
        parent = str(Path(node).parent)
        count = f.get(node, 0)
        if parent in f and f.get(parent, 0) == count and count >= max_child_f.get(node, 0) and count > 0:
            candidate_roots.append(node)

    # 5. Prefer deeper paths: drop ancestors of an accepted root that cover
    #    the same number of models, and drop very shallow paths (< 3 levels)
    roots = []
    for candidate in sorted(candidate_roots, key=lambda p: -len(Path(p).parts)):
        shadowed = any(
            chosen.startswith(candidate + "/") and f.get(candidate, 0) == f.get(chosen, 0) for chosen in roots
        )
        if not shadowed and len(Path(candidate).parts) >= 3:
            roots.append(candidate)

    subtree_sizes = {root: f.get(root, 0) for root in roots}
    return sorted(roots), subtree_sizes
@dataclass
class ModelRootInfo:
    """Information about a detected model root path"""

    # Root directory path
    path: str
    # Number of models associated with this root
    model_count: int
    # The scanned models associated with this root
    models: List[ScannedModel]
def discover_models(
    mount_points: Optional[List[str]] = None, min_size_gb: float = 10.0, max_depth: int = 6
) -> List[ScannedModel]:
    """
    Discover model directories and return them as ScannedModel objects

    Args:
        mount_points: Locations to scan (None = use _get_mount_points())
        min_size_gb: Forwarded to scan_all_models_with_info (default: 10.0)
        max_depth: Forwarded to scan_all_models_with_info (default: 6)

    Returns:
        List of ScannedModel sorted by path
    """
    targets = _get_mount_points() if mount_points is None else mount_points

    # One scan yields type, size and file list for every model directory.
    model_info = scan_all_models_with_info(targets, min_size_gb, max_depth)
    if not model_info:
        return []

    found = [
        ScannedModel(path=model_path, format=model_type, size_bytes=total_size, file_count=file_count, files=files)
        for model_path, (model_type, total_size, file_count, files) in model_info.items()
    ]
    return sorted(found, key=lambda scanned: scanned.path)
def _get_mount_points() -> List[str]:
    """
    Get all valid mount points from /proc/mounts, filtering out system paths

    A mount point is kept when its filesystem type is not a pseudo
    filesystem, it is not "/", it is not a system directory (or located
    below one), and it exists and is readable. System directories are
    matched on whole path components, so e.g. "/optane" or "/var2" are NOT
    treated as "/opt" or "/var". Octal escapes used by /proc/mounts for
    special characters (e.g. "\\040" for a space) are decoded first.

    Returns:
        Sorted list of mount point paths suitable for model storage
        (excludes root "/" to avoid scanning entire filesystem). If nothing
        qualifies, the readable ones among /home, /data and /mnt are
        returned; if /proc/mounts cannot be read, all three are returned
        unchecked.
    """
    import re  # local import: only needed here to decode /proc/mounts escapes

    # System locations unlikely to contain model files
    excluded_roots = (
        "/snap",
        "/proc",
        "/sys",
        "/run",
        "/boot",
        "/dev",
        "/usr",
        "/lib",
        "/lib64",
        "/bin",
        "/sbin",
        "/etc",
        "/opt",
        "/var",
        "/tmp",
    )
    # Built once instead of once per line of /proc/mounts
    pseudo_fs = frozenset(
        (
            "proc",
            "sysfs",
            "devpts",
            "tmpfs",
            "devtmpfs",
            "cgroup",
            "cgroup2",
            "pstore",
            "bpf",
            "tracefs",
            "debugfs",
            "hugetlbfs",
            "mqueue",
            "configfs",
            "securityfs",
            "fuse.gvfsd-fuse",
            "fusectl",
            "squashfs",
            "overlay",
        )
    )

    def _decode(field: str) -> str:
        # /proc/mounts writes whitespace and backslashes as \ooo octal escapes.
        return re.sub(r"\\([0-7]{3})", lambda match: chr(int(match.group(1), 8)), field)

    def _is_system_path(mount_point: str) -> bool:
        # Match the directory itself or anything below it, never a mere
        # string prefix of a differently named directory.
        return any(mount_point == root or mount_point.startswith(root + "/") for root in excluded_roots)

    mount_points = set()
    try:
        with open("/proc/mounts", "r") as f:
            for line in f:
                parts = line.split()
                if len(parts) < 3:
                    continue
                mount_point, fs_type = _decode(parts[1]), parts[2]

                # Filter out pseudo filesystems
                if fs_type in pseudo_fs:
                    continue
                # Skip root directory (too large to scan)
                if mount_point == "/":
                    continue
                # Filter out system paths
                if _is_system_path(mount_point):
                    continue
                # Only include if it exists and is readable
                if os.path.exists(mount_point) and os.access(mount_point, os.R_OK):
                    mount_points.add(mount_point)

        # If no mount points found, fall back to common data directories
        if not mount_points:
            for path in ("/home", "/data", "/mnt"):
                if os.path.exists(path) and os.access(path, os.R_OK):
                    mount_points.add(path)
    except (FileNotFoundError, PermissionError):
        # /proc/mounts unavailable: fall back to common paths
        mount_points = {"/home", "/mnt", "/data"}

    return sorted(mount_points)

View File

@@ -0,0 +1,254 @@
"""
Shared model table builders for consistent UI across commands.
Provides reusable table construction functions for displaying models
in kt model list, kt quant, kt run, etc.
"""
from typing import List, Optional, Tuple
from pathlib import Path
from rich.table import Table
from rich.console import Console
import json
def format_model_size(model_path: Path, format_type: str) -> str:
    """
    Sum the weight files of a model directory and format the total.

    Only files directly inside model_path that match the format's extension
    are counted. A dim "-" placeholder is returned for unknown formats or
    when anything goes wrong while reading the files.
    """
    from kt_kernel.cli.utils.model_scanner import format_size

    placeholder = "[dim]-[/dim]"
    glob_by_format = {"safetensors": "*.safetensors", "gguf": "*.gguf"}
    try:
        pattern = glob_by_format.get(format_type)
        if pattern is None:
            return placeholder
        weight_files = list(model_path.glob(pattern))
        total_bytes = sum(weight_file.stat().st_size for weight_file in weight_files if weight_file.exists())
        return format_size(total_bytes)
    except Exception:
        return placeholder
def format_repo_info(model) -> str:
    """
    Render a model's repository as "<prefix>:<repo_id>".

    The prefix is "hf" when repo_type is "huggingface" and "ms" otherwise.
    Models without a repo_id get a dim "-" placeholder.
    """
    if not model.repo_id:
        return "[dim]-[/dim]"
    prefix = "hf" if model.repo_type == "huggingface" else "ms"
    return f"{prefix}:{model.repo_id}"
def format_sha256_status(model, status_map: dict) -> str:
    """
    Look up the display string for a model's SHA256 verification status.

    A missing/empty status is treated as "not_checked"; a status that is not
    in status_map yields a dim "?".
    """
    status_key = model.sha256_status or "not_checked"
    return status_map.get(status_key, "[dim]?[/dim]")
def build_moe_gpu_table(
    models: List, status_map: dict, show_index: bool = True, start_index: int = 1
) -> Tuple[Table, List]:
    """
    Build the table that lists MoE GPU models.

    Args:
        models: List of MoE GPU model objects
        status_map: Mapping from SHA256 status to its display string
        show_index: Whether to prepend a "#" column for selection (default: True)
        start_index: Number shown for the first row

    Returns:
        Tuple of (Table object, list of models in display order)
    """
    column_specs = []
    if show_index:
        column_specs.append(("#", {"justify": "right", "style": "cyan", "no_wrap": True}))
    column_specs.extend(
        [
            ("Name", {"style": "cyan", "no_wrap": True}),
            ("Path", {"style": "dim", "overflow": "fold"}),
            ("Total", {"justify": "right"}),
            ("Exps", {"justify": "center", "style": "yellow"}),
            ("Act", {"justify": "center", "style": "green"}),
            ("Repository", {"style": "dim", "overflow": "fold"}),
            ("SHA256", {"justify": "center"}),
        ]
    )

    table = Table(show_header=True, header_style="bold", show_lines=False)
    for header, options in column_specs:
        table.add_column(header, **options)

    missing = "[dim]-[/dim]"
    displayed_models = []
    for number, entry in enumerate(models, start_index):
        displayed_models.append(entry)

        # Size is always computed from the safetensors files.
        size_cell = format_model_size(Path(entry.path), "safetensors")

        # Expert counts fall back to a placeholder when unset/zero.
        experts_cell = str(entry.moe_num_experts) if entry.moe_num_experts else missing
        active_cell = str(entry.moe_num_experts_per_tok) if entry.moe_num_experts_per_tok else missing

        repo_cell = format_repo_info(entry)
        sha_cell = format_sha256_status(entry, status_map)

        cells = [entry.name, entry.path, size_cell, experts_cell, active_cell, repo_cell, sha_cell]
        if show_index:
            cells.insert(0, str(number))
        table.add_row(*cells)

    return table, displayed_models
def build_amx_table(
    models: List,
    status_map: dict = None,  # Kept for API compatibility but not used
    show_index: bool = True,
    start_index: int = 1,
    show_linked_gpus: bool = False,
    gpu_models: Optional[List] = None,
) -> Tuple[Table, List]:
    """
    Build the table that lists AMX models.

    There is no SHA256 column in this table. Quantization method and NUMA
    count are taken from the model object first and, failing that, from the
    "amx_quantization" section of the model's config.json.

    Args:
        models: List of AMX model objects
        status_map: (Unused - kept for API compatibility)
        show_index: Whether to prepend a "#" column for selection (default: True)
        start_index: Number shown for the first row
        show_linked_gpus: Whether to add a sub-row naming the linked GPU models
        gpu_models: List of GPU models (needed for the sub-rows)

    Returns:
        Tuple of (Table object, list of models in display order)
    """
    table = Table(show_header=True, header_style="bold", show_lines=False)
    if show_index:
        table.add_column("#", justify="right", style="cyan", no_wrap=True)
    for header, options in (
        ("Name", {"style": "cyan", "no_wrap": True}),
        ("Path", {"style": "dim", "overflow": "fold"}),
        ("Total", {"justify": "right"}),
        ("Method", {"justify": "center", "style": "yellow"}),
        ("NUMA", {"justify": "center", "style": "green"}),
        ("Source", {"style": "dim", "overflow": "fold"}),
    ):
        table.add_column(header, **options)

    # AMX model id -> names of the GPU models referenced by gpu_model_ids
    linked_gpu_names = {}
    if show_linked_gpus and gpu_models:
        for entry in models:
            if not entry.gpu_model_ids:
                continue
            names = []
            for gpu_id in entry.gpu_model_ids:
                # First GPU model with a matching id wins.
                match = next((gpu for gpu in gpu_models if gpu.id == gpu_id), None)
                if match is not None:
                    names.append(match.name)
            if names:
                linked_gpu_names[entry.id] = names

    unknown = "[dim]?[/dim]"
    displayed_models = []
    for number, entry in enumerate(models, start_index):
        displayed_models.append(entry)
        size_cell = format_model_size(Path(entry.path), "safetensors")

        # Best-effort read of the quantization metadata from config.json.
        config_method = None
        config_numa = None
        try:
            config_file = Path(entry.path) / "config.json"
            if config_file.exists():
                with open(config_file, "r", encoding="utf-8") as handle:
                    quant_info = json.load(handle).get("amx_quantization", {})
                if quant_info.get("converted"):
                    config_method = quant_info.get("method")
                    config_numa = quant_info.get("numa_count")
        except Exception:
            pass

        # Priority: model fields > config.json > "?"
        if entry.amx_quant_method:
            method_cell = entry.amx_quant_method.upper()
        elif config_method:
            method_cell = config_method.upper()
        else:
            method_cell = unknown

        if entry.amx_numa_nodes:
            numa_cell = str(entry.amx_numa_nodes)
        elif config_numa:
            numa_cell = str(config_numa)
        else:
            numa_cell = unknown

        source_cell = entry.amx_source_model or "[dim]-[/dim]"

        cells = [entry.name, entry.path, size_cell, method_cell, numa_cell, source_cell]
        if show_index:
            cells.insert(0, str(number))
        table.add_row(*cells)

        # Optional sub-row naming the GPU models linked to this AMX model.
        if show_linked_gpus and entry.id in linked_gpu_names:
            joined = ", ".join(f"[dim]{name}[/dim]" for name in linked_gpu_names[entry.id])
            sub_cells = [f" [dim]↳ GPU: {joined}[/dim]", "", "", "", "", ""]
            if show_index:
                sub_cells.insert(0, "")
            table.add_row(*sub_cells, style="dim")

    return table, displayed_models
def build_gguf_table(
    models: List, status_map: dict, show_index: bool = True, start_index: int = 1
) -> Tuple[Table, List]:
    """
    Build the table that lists GGUF models.

    Args:
        models: List of GGUF model objects
        status_map: Mapping from SHA256 status to its display string
        show_index: Whether to prepend a "#" column for selection (default: True)
        start_index: Number shown for the first row

    Returns:
        Tuple of (Table object, list of models in display order)
    """
    table = Table(show_header=True, header_style="bold", show_lines=False)
    if show_index:
        table.add_column("#", justify="right", style="cyan", no_wrap=True)
    for header, options in (
        ("Name", {"style": "cyan", "no_wrap": True}),
        ("Path", {"style": "dim", "overflow": "fold"}),
        ("Total", {"justify": "right"}),
        ("Repository", {"style": "dim", "overflow": "fold"}),
        ("SHA256", {"justify": "center"}),
    ):
        table.add_column(header, **options)

    displayed_models = []
    for number, entry in enumerate(models, start_index):
        displayed_models.append(entry)

        size_cell = format_model_size(Path(entry.path), "gguf")
        repo_cell = format_repo_info(entry)
        sha_cell = format_sha256_status(entry, status_map)

        cells = [entry.name, entry.path, size_cell, repo_cell, sha_cell]
        if show_index:
            cells.insert(0, str(number))
        table.add_row(*cells)

    return table, displayed_models

View File

@@ -0,0 +1,918 @@
"""
Model Verifier
SHA256 verification for model integrity
"""
import hashlib
import requests
import os
from pathlib import Path
from typing import Dict, Any, Literal, Tuple
from concurrent.futures import ProcessPoolExecutor, as_completed
def _compute_file_sha256(file_path: Path) -> Tuple[str, str, float]:
"""
Compute SHA256 for a single file (worker function for multiprocessing).
Args:
file_path: Path to the file
Returns:
Tuple of (filename, sha256_hash, file_size_mb)
"""
sha256_hash = hashlib.sha256()
file_size_mb = file_path.stat().st_size / (1024 * 1024)
# Read file in chunks to handle large files
with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(8192 * 1024), b""): # 8MB chunks
sha256_hash.update(byte_block)
return file_path.name, sha256_hash.hexdigest(), file_size_mb
def check_huggingface_connectivity(timeout: int = 5) -> Tuple[bool, str]:
    """
    Probe https://huggingface.co with a HEAD request.

    Args:
        timeout: Connection timeout in seconds

    Returns:
        Tuple of (is_accessible, message)
    """
    probe_url = "https://huggingface.co"

    try:
        reply = requests.head(probe_url, timeout=timeout, allow_redirects=True)
    except requests.exceptions.Timeout:
        return False, f"Connection to {probe_url} timed out"
    except requests.exceptions.ConnectionError:
        return False, f"Cannot connect to {probe_url}"
    except requests.exceptions.RequestException as exc:
        return False, f"Connection error: {str(exc)}"

    # Any answer below 500 (2xx, 3xx, 4xx) proves the host is reachable.
    if reply.status_code < 500:
        return True, "HuggingFace is accessible"
    return False, "Unknown connection error"
def verify_model_integrity(
    repo_type: Literal["huggingface", "modelscope"],
    repo_id: str,
    local_dir: Path,
    progress_callback=None,
) -> Dict[str, Any]:
    """
    Verify local model integrity against remote repository SHA256 hashes.

    Verifies all important files:
    - *.safetensors (weights)
    - *.json (config files)
    - *.py (custom model code)

    Remote entries whose hash is not a 64-character string (for example ``None``)
    cannot equal a locally computed SHA256 hex digest. They are skipped and
    reported through the progress callback instead of being flagged as corrupted
    or aborting the whole verification.

    Args:
        repo_type: Type of repository ("huggingface" or "modelscope")
        repo_id: Repository ID (e.g., "deepseek-ai/DeepSeek-V3")
        local_dir: Local directory containing model files
        progress_callback: Optional callback function(message: str) for progress updates

    Returns:
        Dictionary with verification results:
        {
            "status": "passed" | "failed" | "error",
            "files_checked": int,
            "files_passed": int,
            "files_failed": [list of filenames],
            "error_message": str (optional),
            "is_network_error": bool (only on results built from a caught exception)
        }
    """

    def report_progress(msg: str):
        """Forward a progress message to the optional callback."""
        if progress_callback:
            progress_callback(msg)

    def error_result(message: str, **extra) -> Dict[str, Any]:
        """Build the common "error" result; ``extra`` carries optional keys."""
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": message,
            **extra,
        }

    try:
        # Convert repo_type to platform format
        platform = "hf" if repo_type == "huggingface" else "ms"

        # 1. Fetch official SHA256 hashes from remote
        report_progress("Fetching official SHA256 hashes from remote repository...")
        official_hashes = fetch_model_sha256(repo_id, platform)
        report_progress(f"✓ Fetched {len(official_hashes)} file hashes from remote")

        # Keep only entries that can be compared with a local SHA256 hex digest.
        # NOTE(review): _fetch_from_huggingface falls back to `blob_id` for files
        # without LFS metadata; that is presumably a git object id rather than a
        # SHA256 of the content - confirm against the huggingface_hub docs.
        verifiable_hashes = {
            name: digest
            for name, digest in official_hashes.items()
            if isinstance(digest, str) and len(digest) == 64
        }
        skipped = len(official_hashes) - len(verifiable_hashes)
        if skipped:
            report_progress(f" Skipping {skipped} file(s) without a remote SHA256 digest")
        official_hashes = verifiable_hashes

        if not official_hashes:
            return error_result(f"No verifiable files found in remote repository: {repo_id}")

        # 2. Calculate local SHA256 hashes with progress
        report_progress("Calculating SHA256 for local files...")

        # Get all local files matching the patterns
        local_files = []
        for pattern in ["*.safetensors", "*.json", "*.py"]:
            local_files.extend(f for f in local_dir.glob(pattern) if f.is_file())

        if not local_files:
            return error_result(f"No verifiable files found in local directory: {local_dir}")

        local_hashes = calculate_local_sha256(
            local_dir,
            file_pattern="*.safetensors",  # Unused when files_list is provided
            progress_callback=report_progress,
            files_list=local_files,
        )
        report_progress(f"✓ Calculated {len(local_hashes)} local file hashes")

        # Index local hashes by basename once; setdefault keeps the first entry,
        # matching the previous first-match scan.
        local_by_name: Dict[str, str] = {}
        for local_file, local_hash_value in local_hashes.items():
            local_by_name.setdefault(Path(local_file).name, local_hash_value)

        # 3. Compare hashes with progress
        total_checked = len(official_hashes)
        report_progress(f"Comparing {total_checked} files...")
        files_failed = []
        files_missing = []
        files_passed = 0

        for idx, (filename, official_hash) in enumerate(official_hashes.items(), 1):
            # Remote names may contain path separators; local hashes are keyed by basename
            file_basename = Path(filename).name
            local_hash = local_by_name.get(file_basename)

            if local_hash is None:
                files_missing.append(filename)
                report_progress(f" [{idx}/{total_checked}] ✗ {file_basename} - MISSING")
            elif local_hash.lower() != official_hash.lower():
                files_failed.append(f"{filename} (hash mismatch)")
                report_progress(f" [{idx}/{total_checked}] ✗ {file_basename} - HASH MISMATCH")
            else:
                files_passed += 1
                report_progress(f" [{idx}/{total_checked}] ✓ {file_basename}")

        # 4. Return results
        if files_failed or files_missing:
            all_failed = files_failed + [f"{f} (missing)" for f in files_missing]
            return {
                "status": "failed",
                "files_checked": total_checked,
                "files_passed": files_passed,
                "files_failed": all_failed,
                "error_message": f"{len(all_failed)} file(s) failed verification",
            }
        return {
            "status": "passed",
            "files_checked": total_checked,
            "files_passed": files_passed,
            "files_failed": [],
        }

    except ImportError as e:
        return error_result(
            f"Missing required package: {str(e)}. Install with: pip install huggingface-hub modelscope",
            is_network_error=False,
        )
    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
        requests.exceptions.RequestException,
    ) as e:
        # Network-related errors - suggest mirror
        error_msg = f"Network error: {str(e)}"
        if repo_type == "huggingface":
            error_msg += "\n\nTry using HuggingFace mirror:\n export HF_ENDPOINT=https://hf-mirror.com"
        return error_result(error_msg, is_network_error=True)
    except Exception as e:
        return error_result(f"Verification failed: {str(e)}", is_network_error=False)
def calculate_local_sha256(
    local_dir: Path, file_pattern: str = "*.safetensors", progress_callback=None, files_list: list[Path] = None
) -> Dict[str, str]:
    """
    Hash files of a directory with SHA256, spreading the work over a process pool.

    Args:
        local_dir: Directory to scan
        file_pattern: Glob pattern for files to hash (ignored if files_list is provided)
        progress_callback: Optional callback function(message: str) for progress updates
        files_list: Optional pre-filtered list of files to hash (overrides file_pattern)

    Returns:
        Dictionary mapping filename to SHA256 hash. Files whose hashing raised
        are reported through the callback and left out of the mapping.
    """
    hashes: Dict[str, str] = {}
    if not local_dir.exists():
        return hashes

    # Resolve the work list up front so progress can show "[done/total]".
    if files_list is None:
        pending = [candidate for candidate in local_dir.glob(file_pattern) if candidate.is_file()]
    else:
        pending = files_list

    total = len(pending)
    if not total:
        return hashes

    # At most 16 processes, and never more processes than files.
    worker_count = min(16, total)
    if progress_callback:
        progress_callback(f" Using {worker_count} parallel workers for SHA256 calculation")

    with ProcessPoolExecutor(max_workers=worker_count) as pool:
        jobs = {pool.submit(_compute_file_sha256, path): path for path in pending}

        for done, job in enumerate(as_completed(jobs), 1):
            try:
                name, digest, size_mb = job.result()
                hashes[name] = digest
                if progress_callback:
                    progress_callback(f" [{done}/{total}] ✓ {name} ({size_mb:.1f} MB)")
            except Exception as exc:
                failed_path = jobs[job]
                if progress_callback:
                    progress_callback(f" [{done}/{total}] ✗ {failed_path.name} - Error: {str(exc)}")

    return hashes
def fetch_model_sha256(
repo_id: str,
platform: Literal["hf", "ms"],
revision: str | None = None,
use_mirror: bool = False,
timeout: int | None = None,
) -> dict[str, str]:
"""
获取模型仓库中所有重要文件的 sha256 哈希值。
包括:
- *.safetensors (权重文件)
- *.json (配置文件config.json, tokenizer_config.json 等)
- *.py (自定义模型代码modeling.py, configuration.py 等)
Args:
repo_id: 仓库 ID例如 "Qwen/Qwen3-30B-A3B"
platform: 平台,"hf" (HuggingFace) 或 "ms" (ModelScope)
revision: 版本/分支,默认 HuggingFace 为 "main"ModelScope 为 "master"
use_mirror: 是否使用镜像(仅对 HuggingFace 有效)
timeout: 网络请求超时时间None 表示不设置超时
Returns:
dict: 文件名到 sha256 的映射,例如 {"model-00001-of-00016.safetensors": "abc123...", "config.json": "def456..."}
"""
if platform == "hf":
# 先尝试直连,失败后自动使用镜像
try:
if use_mirror:
return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=True, timeout=timeout)
else:
return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=False, timeout=timeout)
except Exception as e:
# 如果不是镜像模式且失败了,自动重试使用镜像
if not use_mirror:
return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=True, timeout=timeout)
else:
raise e
elif platform == "ms":
return _fetch_from_modelscope(repo_id, revision or "master", timeout=timeout)
else:
raise ValueError(f"不支持的平台: {platform},请使用 'hf''ms'")
def _fetch_from_huggingface(
    repo_id: str, revision: str, use_mirror: bool = False, timeout: int | None = None
) -> dict[str, str]:
    """Fetch the hashes of all important files (*.safetensors, *.json, *.py) from HuggingFace.

    The mirror is selected by passing ``endpoint`` to ``HfApi`` instead of
    temporarily setting ``HF_ENDPOINT``, so the choice takes effect even when
    ``huggingface_hub`` was already imported, and the process environment is
    never modified.

    Args:
        repo_id: Repository ID
        revision: Version/branch
        use_mirror: Use the hf-mirror.com mirror. Ignored when the user already
            configured ``HF_ENDPOINT``; the library default endpoint is used then.
        timeout: Network request timeout in seconds; None leaves the default untouched

    Returns:
        Mapping of repository file path to hash. For files with LFS metadata this
        is ``lfs.sha256``; otherwise it is the file's ``blob_id`` (or None).
        NOTE(review): ``blob_id`` is presumably a git object id, not a SHA256 of
        the content - callers comparing against a local SHA256 should confirm.

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
    """
    import os
    import socket

    # Import before touching any global state so an ImportError leaves nothing to restore.
    from huggingface_hub import HfApi

    # None lets HfApi use the library default, which honours a user-set HF_ENDPOINT.
    endpoint = None
    if use_mirror and not os.environ.get("HF_ENDPOINT"):
        endpoint = "https://hf-mirror.com"

    # The timeout is applied through the process-wide socket default and restored below.
    original_timeout = socket.getdefaulttimeout()
    if timeout is not None:
        socket.setdefaulttimeout(timeout)

    try:
        api = HfApi(endpoint=endpoint)
        all_files = api.list_repo_files(repo_id=repo_id, revision=revision)

        # Keep the important files only: *.safetensors, *.json, *.py
        important_files = [f for f in all_files if f.endswith((".safetensors", ".json", ".py"))]
        if not important_files:
            return {}

        paths_info = api.get_paths_info(
            repo_id=repo_id,
            paths=important_files,
            revision=revision,
        )

        result = {}
        for file_info in paths_info:
            lfs_info = getattr(file_info, "lfs", None)
            if lfs_info is not None:
                result[file_info.path] = lfs_info.sha256
            else:
                result[file_info.path] = getattr(file_info, "blob_id", None)
        return result
    finally:
        # Restore the original socket timeout
        socket.setdefaulttimeout(original_timeout)
def _fetch_from_modelscope(repo_id: str, revision: str, timeout: int | None = None) -> dict[str, str]:
    """Fetch the hashes of all important files (*.safetensors, *.json, *.py) from ModelScope.

    Args:
        repo_id: Repository ID
        revision: Version/branch
        timeout: Network request timeout; None means no timeout is set

    Returns:
        Mapping of file name to the value of the entry's "Sha256" (or "sha256")
        field, None when the entry has neither.
    """
    import socket

    from modelscope.hub.api import HubApi

    # The timeout is applied through the process-wide socket default and restored below.
    previous_timeout = socket.getdefaulttimeout()
    if timeout is not None:
        socket.setdefaulttimeout(timeout)

    try:
        listing = HubApi().get_model_files(model_id=repo_id, revision=revision)

        hashes = {}
        for entry in listing:
            name = entry.get("Name", entry.get("Path", ""))
            # Keep the important files only: *.safetensors, *.json, *.py
            if not name.endswith((".safetensors", ".json", ".py")):
                continue
            hashes[name] = entry.get("Sha256", entry.get("sha256", None))
        return hashes
    finally:
        # Restore the original socket timeout
        socket.setdefaulttimeout(previous_timeout)
def verify_model_integrity_with_progress(
    repo_type: Literal["huggingface", "modelscope"],
    repo_id: str,
    local_dir: Path,
    progress_callback=None,
    verbose: bool = False,
    use_mirror: bool = False,
    files_to_verify: list[str] | None = None,
    timeout: int | None = None,
) -> Dict[str, Any]:
    """
    Verify model integrity with enhanced progress reporting for Rich Progress bars.

    Performs the same fetch / hash / compare steps as verify_model_integrity()
    but reports (message, total, current) so a progress bar can be driven.

    The remote listing covers *.safetensors, *.json and *.py, so the same three
    file types are hashed locally; hashing only the weights made every remote
    config/code file show up as "missing". Remote entries whose hash is not a
    64-character string (for example ``None``) cannot match a local SHA256 and
    are skipped.

    The progress_callback is always invoked as (message, total, current);
    total and current are None for plain status updates.

    Args:
        repo_type: Repository type ("huggingface" or "modelscope")
        repo_id: Repository ID
        local_dir: Local directory path
        progress_callback: Optional callback for progress updates
        verbose: If True, output detailed SHA256 comparison for each file
        use_mirror: If True, use HuggingFace mirror (hf-mirror.com)
        files_to_verify: Optional list of specific files to verify (for re-verification);
            entries may carry the " (missing)" / " (hash mismatch)" markers of a previous result
        timeout: Network request timeout in seconds (None = no timeout)

    Returns:
        Dictionary with "status" ("passed" | "failed" | "error"), "files_checked",
        "files_passed", "files_failed", and on failure/error "error_message"
        (plus "is_network_error" on results built from a caught exception).
    """

    def report_progress(msg: str, total=None, current=None):
        """Enhanced progress reporter"""
        if progress_callback:
            progress_callback(msg, total, current)

    try:
        platform = "hf" if repo_type == "huggingface" else "ms"

        # Basenames requested for re-verification, with result markers removed.
        # None means "verify everything".
        requested_names = None
        if files_to_verify:
            requested_names = {
                Path(f.replace(" (missing)", "").replace(" (hash mismatch)", "").strip()).name
                for f in files_to_verify
            }

        # 1. Fetch official SHA256 hashes
        if files_to_verify:
            report_progress(f"Fetching SHA256 hashes for {len(files_to_verify)} files...")
        elif use_mirror and platform == "hf":
            report_progress("Fetching official SHA256 hashes from mirror (hf-mirror.com)...")
        else:
            report_progress("Fetching official SHA256 hashes from remote repository...")

        official_hashes = fetch_model_sha256(repo_id, platform, use_mirror=use_mirror, timeout=timeout)

        # Remote keys may contain directories, so compare by basename
        if requested_names is not None:
            official_hashes = {k: v for k, v in official_hashes.items() if Path(k).name in requested_names}

        report_progress(f"✓ Fetched {len(official_hashes)} file hashes from remote")

        # Keep only entries that can be compared with a local SHA256 hex digest.
        verifiable_hashes = {
            name: digest
            for name, digest in official_hashes.items()
            if isinstance(digest, str) and len(digest) == 64
        }
        skipped = len(official_hashes) - len(verifiable_hashes)
        if skipped:
            report_progress(f"Skipping {skipped} file(s) without a remote SHA256 digest")
        official_hashes = verifiable_hashes

        if not official_hashes:
            return {
                "status": "error",
                "files_checked": 0,
                "files_passed": 0,
                "files_failed": [],
                "error_message": f"No verifiable files found in remote repository: {repo_id}",
            }

        # 2. Calculate local SHA256 hashes for the same file types the remote listing covers
        local_dir_path = Path(local_dir)
        files_to_hash = []
        for pattern in ("*.safetensors", "*.json", "*.py"):
            files_to_hash.extend(
                f
                for f in local_dir_path.glob(pattern)
                if f.is_file() and (requested_names is None or f.name in requested_names)
            )
        total_files = len(files_to_hash)

        if files_to_verify:
            report_progress(f"Calculating SHA256 for {total_files} repaired files...", total=total_files, current=0)
        else:
            report_progress("Calculating SHA256 for local files...", total=total_files, current=0)

        # Progress wrapper for hashing
        completed_count = [0]  # Use list for mutable closure

        def hash_progress_callback(msg: str):
            if "Using" in msg and "workers" in msg:
                report_progress(msg)
            elif "[" in msg and "/" in msg and "]" in msg:
                # Progress update like: [1/10] ✓ filename (123.4 MB)
                completed_count[0] += 1
                report_progress(msg, total=total_files, current=completed_count[0])

        local_hashes = calculate_local_sha256(
            local_dir_path,
            "*.safetensors",  # Unused because files_list is provided
            progress_callback=hash_progress_callback,
            files_list=files_to_hash,
        )
        report_progress(f"✓ Calculated {len(local_hashes)} local file hashes")

        # Index local hashes by basename once; setdefault keeps the first entry.
        local_by_name: Dict[str, str] = {}
        for local_file, local_hash_value in local_hashes.items():
            local_by_name.setdefault(Path(local_file).name, local_hash_value)

        # 3. Compare hashes
        total_checked = len(official_hashes)
        report_progress(f"Comparing {total_checked} files...", total=total_checked, current=0)
        files_failed = []
        files_missing = []
        files_passed = 0

        for idx, (filename, official_hash) in enumerate(official_hashes.items(), 1):
            file_basename = Path(filename).name
            local_hash = local_by_name.get(file_basename)

            if local_hash is None:
                files_missing.append(filename)
                line = f"[{idx}/{total_checked}] ✗ {file_basename} (missing)"
                shown_local = "<missing>"
            elif local_hash.lower() != official_hash.lower():
                files_failed.append(f"{filename} (hash mismatch)")
                line = f"[{idx}/{total_checked}] ✗ {file_basename} (hash mismatch)"
                shown_local = local_hash
            else:
                files_passed += 1
                line = f"[{idx}/{total_checked}] ✓ {file_basename}"
                shown_local = local_hash

            if verbose:
                line += f"\n Remote: {official_hash}\n Local: {shown_local}"
            report_progress(line, total=total_checked, current=idx)

        # 4. Return results
        if files_failed or files_missing:
            all_failed = files_failed + [f"{f} (missing)" for f in files_missing]
            return {
                "status": "failed",
                "files_checked": total_checked,
                "files_passed": files_passed,
                "files_failed": all_failed,
                "error_message": f"{len(all_failed)} file(s) failed verification",
            }
        return {
            "status": "passed",
            "files_checked": total_checked,
            "files_passed": files_passed,
            "files_failed": [],
        }

    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
        requests.exceptions.RequestException,
        TimeoutError,  # Socket timeout from socket.setdefaulttimeout()
        OSError,  # Network-related OS errors
    ) as e:
        error_msg = f"Network error: {str(e)}"
        if repo_type == "huggingface":
            error_msg += "\n\nTry using HuggingFace mirror:\n export HF_ENDPOINT=https://hf-mirror.com"
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": error_msg,
            "is_network_error": True,
        }
    except Exception as e:
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": f"Verification failed: {str(e)}",
            "is_network_error": False,
        }
def pre_operation_verification(user_model, user_registry, operation_name: str = "operation") -> None:
    """Interactively verify a model's integrity before an operation.

    Can be used before running or quantizing models. The flow is:
    skip when already verified -> optionally configure the repository ->
    ask for confirmation -> fetch remote hashes (direct, then mirror, then
    ModelScope) -> hash local files -> compare -> record the result in the
    registry and let the user decide how to proceed on failure.

    The remote listing covers *.safetensors, *.json and *.py, so the same file
    types are hashed locally. Remote entries whose hash is not a 64-character
    string cannot match a local SHA256 and are ignored.

    Args:
        user_model: UserModel object to verify (reads name, path, repo_id,
            repo_type, sha256_status; repo_id/repo_type are set when configured here)
        user_registry: UserModelRegistry instance used to persist changes
        operation_name: Name of the operation (e.g., "running", "quantizing")

    Raises:
        typer.Exit: With code 0 when the user chooses not to continue.
    """
    from rich.prompt import Prompt, Confirm
    from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, MofNCompleteColumn, TimeElapsedColumn
    from concurrent.futures import ThreadPoolExecutor
    from kt_kernel.cli.utils.console import console, print_info, print_warning, print_error, print_success, print_step
    import typer

    # Check if already verified
    if user_model.sha256_status == "passed":
        console.print()
        print_info("Model integrity already verified ✓")
        console.print()
        return

    # Model not verified yet
    console.print()
    console.print("[bold yellow]═══ Model Integrity Check ═══[/bold yellow]")
    console.print()

    # Check if repo_id exists
    if not user_model.repo_id:
        # No repo_id - ask user to provide one
        console.print("[yellow]No repository ID configured for this model.[/yellow]")
        console.print()
        console.print("To verify model integrity, we need the repository ID (e.g., 'deepseek-ai/DeepSeek-V3')")
        console.print()
        if not Confirm.ask("Would you like to configure repository ID now?", default=True):
            console.print()
            print_warning(f"Skipping verification. Model will be used for {operation_name} without integrity check.")
            console.print()
            return

        # Ask for repo type
        console.print()
        console.print("Repository type:")
        console.print(" [cyan][1][/cyan] HuggingFace")
        console.print(" [cyan][2][/cyan] ModelScope")
        console.print()
        repo_type_choice = Prompt.ask("Select repository type", choices=["1", "2"], default="1")
        repo_type = "huggingface" if repo_type_choice == "1" else "modelscope"

        # Ask for repo_id
        console.print()
        repo_id = Prompt.ask("Enter repository ID (e.g., deepseek-ai/DeepSeek-V3)")

        # Persist the choice and mirror it on the in-memory object used below
        user_registry.update_model(user_model.name, {"repo_type": repo_type, "repo_id": repo_id})
        user_model.repo_type = repo_type
        user_model.repo_id = repo_id
        console.print()
        print_success(f"Repository configured: {repo_type}:{repo_id}")
        console.print()

    # Now ask if user wants to verify
    console.print("[dim]Model integrity verification is a one-time check that ensures your[/dim]")
    console.print("[dim]model weights are not corrupted. This helps prevent runtime errors.[/dim]")
    console.print()
    if not Confirm.ask(f"Would you like to verify model integrity before {operation_name}?", default=True):
        console.print()
        print_warning(f"Skipping verification. Model will be used for {operation_name} without integrity check.")
        console.print()
        return

    # Perform verification
    console.print()
    print_step("Verifying model integrity...")
    console.print()

    # Check connectivity first
    use_mirror = False
    if user_model.repo_type == "huggingface":
        with console.status("[dim]Checking HuggingFace connectivity...[/dim]"):
            is_accessible, message = check_huggingface_connectivity(timeout=5)
        if not is_accessible:
            print_warning("HuggingFace Connection Failed")
            console.print()
            console.print(f" {message}")
            console.print()
            console.print(" [yellow]Auto-switching to HuggingFace mirror:[/yellow] [cyan]hf-mirror.com[/cyan]")
            console.print()
            use_mirror = True

    def fetch_with_timeout(repo_type, repo_id, use_mirror, timeout):
        """Fetch hashes in a worker thread; returns (hashes, failed).

        Any failure (timeout or other exception) yields (None, True).
        """
        executor = ThreadPoolExecutor(max_workers=1)
        try:
            platform = "hf" if repo_type == "huggingface" else "ms"
            future = executor.submit(fetch_model_sha256, repo_id, platform, use_mirror=use_mirror, timeout=timeout)
            return (future.result(timeout=timeout), False)
        except Exception:
            return (None, True)
        finally:
            # Do not wait: a worker that is still blocked must not stall the CLI
            executor.shutdown(wait=False)

    def fetch_with_status(label, repo_type, mirror):
        """Run one fetch attempt while showing a spinner; the spinner always stops."""
        with console.status(label):
            return fetch_with_timeout(repo_type, user_model.repo_id, mirror, 10)

    # Try fetching hashes
    official_hashes, timed_out = fetch_with_status(
        "[dim]Fetching remote hashes...[/dim]", user_model.repo_type, use_mirror
    )

    # Handle timeout with fallback
    if timed_out and user_model.repo_type == "huggingface" and not use_mirror:
        print_warning("HuggingFace Fetch Timeout (10s)")
        console.print()
        console.print(" [yellow]Trying HuggingFace mirror...[/yellow]")
        console.print()
        official_hashes, timed_out = fetch_with_status(
            "[dim]Fetching remote hashes from mirror...[/dim]", user_model.repo_type, True
        )

    if timed_out and user_model.repo_type == "huggingface":
        print_warning("HuggingFace Mirror Timeout (10s)")
        console.print()
        console.print(" [yellow]Fallback to ModelScope...[/yellow]")
        console.print()
        official_hashes, timed_out = fetch_with_status(
            "[dim]Fetching remote hashes from ModelScope...[/dim]", "modelscope", False
        )

    # Keep only entries that can be compared with a local SHA256 hex digest.
    if official_hashes:
        official_hashes = {
            name: digest
            for name, digest in official_hashes.items()
            if isinstance(digest, str) and len(digest) == 64
        }

    if not official_hashes or timed_out:
        print_error("Failed to fetch remote hashes (network timeout)")
        console.print()
        console.print(" [yellow]Unable to verify model integrity due to network issues.[/yellow]")
        console.print()
        if not Confirm.ask(f"Continue {operation_name} without verification?", default=False):
            raise typer.Exit(0)
        console.print()
        return

    console.print(f" [green]✓ Fetched {len(official_hashes)} file hashes[/green]")
    console.print()

    # Hash the same file types the remote listing covers
    local_dir = Path(user_model.path)
    files_to_hash = []
    for pattern in ("*.safetensors", "*.json", "*.py"):
        files_to_hash.extend(f for f in local_dir.glob(pattern) if f.is_file())

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        console=console,
    ) as progress:
        # Calculate local hashes
        task = progress.add_task("[yellow]Calculating local SHA256...", total=len(files_to_hash))

        def hash_callback(msg):
            # Per-file messages look like "[3/10] ..."; the worker-count notice has no brackets
            if "[" in msg and "/" in msg and "]" in msg:
                progress.advance(task)

        local_hashes = calculate_local_sha256(
            local_dir, "*.safetensors", progress_callback=hash_callback, files_list=files_to_hash
        )
        progress.remove_task(task)
        console.print(f" [green]✓ Calculated {len(local_hashes)} local hashes[/green]")
        console.print()

        # Index local hashes by basename once; setdefault keeps the first entry.
        local_by_name = {}
        for local_file, local_hash_value in local_hashes.items():
            local_by_name.setdefault(Path(local_file).name, local_hash_value)

        # Compare hashes
        task = progress.add_task("[blue]Comparing hashes...", total=len(official_hashes))
        files_failed = []
        files_missing = []
        files_passed = 0
        for filename, official_hash in official_hashes.items():
            local_hash = local_by_name.get(Path(filename).name)
            if local_hash is None:
                files_missing.append(filename)
            elif local_hash.lower() != official_hash.lower():
                files_failed.append(f"{filename} (hash mismatch)")
            else:
                files_passed += 1
            progress.advance(task)
        progress.remove_task(task)

    console.print()

    # Check results
    if not files_failed and not files_missing:
        # Verification passed
        user_registry.update_model(user_model.name, {"sha256_status": "passed"})
        print_success("Model integrity verification PASSED ✓")
        console.print()
        console.print(f" All {files_passed} files verified successfully")
        console.print()
    else:
        # Verification failed
        user_registry.update_model(user_model.name, {"sha256_status": "failed"})
        print_error("Model integrity verification FAILED")
        console.print()
        console.print(f" ✓ Passed: [green]{files_passed}[/green]")
        console.print(f" ✗ Failed: [red]{len(files_failed) + len(files_missing)}[/red]")
        console.print()

        # Show at most five entries per category to keep the output short
        if files_missing:
            console.print(f" [red]Missing files ({len(files_missing)}):[/red]")
            for f in files_missing[:5]:
                console.print(f" - {Path(f).name}")
            if len(files_missing) > 5:
                console.print(f" ... and {len(files_missing) - 5} more")
            console.print()

        if files_failed:
            console.print(f" [red]Hash mismatch ({len(files_failed)}):[/red]")
            for f in files_failed[:5]:
                console.print(f" - {f}")
            if len(files_failed) > 5:
                console.print(f" ... and {len(files_failed) - 5} more")
            console.print()

        console.print("[bold red]⚠ WARNING: Model weights may be corrupted![/bold red]")
        console.print()
        console.print("This could cause runtime errors or incorrect inference results.")
        console.print()

        # Ask if user wants to repair
        if Confirm.ask("Would you like to repair (re-download) the corrupted files?", default=True):
            console.print()
            print_info("Please run: [cyan]kt model verify " + user_model.name + "[/cyan]")
            console.print()
            console.print("The verify command will guide you through the repair process.")
            raise typer.Exit(0)

        # Ask if user wants to continue anyway
        console.print()
        if not Confirm.ask(
            f"[yellow]Continue {operation_name} with potentially corrupted weights?[/yellow]", default=False
        ):
            raise typer.Exit(0)
        console.print()
        print_warning(f"Proceeding with {operation_name} using unverified weights at your own risk...")
        console.print()

View File

@@ -0,0 +1,57 @@
"""
Port availability checking utilities.
"""
import socket
from typing import Tuple
def is_port_available(host: str, port: int) -> bool:
    """Check if a port is available on the given host.

    The check is a TCP connect probe: if something accepts a connection on
    (host, port) the port is considered occupied. No bind is attempted.

    Args:
        host: Host address (e.g., "0.0.0.0", "127.0.0.1")
        port: Port number to check

    Returns:
        True if port is available, False if occupied or if the probe itself
        failed (e.g. unresolvable host, port outside 0-65535)
    """
    # The wildcard address cannot be connected to; probe loopback instead.
    probe_host = "127.0.0.1" if host == "0.0.0.0" else host

    try:
        # The context manager closes the socket even when the probe raises.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1)
            # connect_ex returns 0 only when a listener accepted the connection,
            # i.e. the port is occupied; any error code means nothing is listening.
            return sock.connect_ex((probe_host, port)) != 0
    except Exception:
        # Best effort: if the probe cannot be performed, assume the port is not available
        return False
def find_available_port(host: str, start_port: int, max_attempts: int = 100) -> Tuple[bool, int]:
    """Scan upwards from start_port for the first available port.

    Args:
        host: Host address
        start_port: Starting port number to check
        max_attempts: Maximum number of ports to try

    Returns:
        Tuple of (found, port_number)
        - found: True if an available port was found
        - port_number: The available port number (or start_port if not found)
    """
    candidate = start_port
    stop = start_port + max_attempts  # exclusive upper bound of the scan

    while candidate < stop:
        if is_port_available(host, candidate):
            return True, candidate
        candidate += 1

    return False, start_port

View File

@@ -0,0 +1,347 @@
"""
Interactive configuration for kt quant command.
Provides rich, multi-step interactive configuration for model quantization.
"""
from typing import Optional, Dict, Any
from pathlib import Path
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.prompt import Prompt, Confirm, IntPrompt
from kt_kernel.cli.i18n import t
# Module-level Rich console used by the interactive prompts in this module.
console = Console()
def select_model_to_quantize() -> Optional[Any]:
    """Let the user pick one registered MoE safetensors model to quantize.

    Returns:
        The selected model, or None when no candidate exists or the entered
        number is out of range.
    """
    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
    from kt_kernel.cli.commands.model import is_amx_weights, SHA256_STATUS_MAP
    from kt_kernel.cli.utils.model_table_builder import build_moe_gpu_table

    # Candidates: safetensors format, not already AMX weights, flagged as MoE
    candidates = []
    for entry in UserModelRegistry().list_models():
        if entry.format != "safetensors":
            continue
        already_amx, _ = is_amx_weights(entry.path)
        if already_amx:
            continue
        if entry.is_moe:
            candidates.append(entry)

    if not candidates:
        console.print(f"[yellow]{t('quant_no_moe_models')}[/yellow]")
        console.print()
        console.print(f" {t('quant_only_moe')}")
        console.print()
        console.print(f" {t('quant_add_models', command='kt model scan')}")
        console.print(f" {t('quant_add_models', command='kt model add <path>')}")
        return None

    # Display models
    console.print()
    console.print(f"[bold green]{t('quant_moe_available')}[/bold green]")
    console.print()

    # Shared table builder; it also returns the models in display order
    table, shown = build_moe_gpu_table(
        models=candidates, status_map=SHA256_STATUS_MAP, show_index=True, start_index=1
    )
    console.print(table)
    console.print()

    picked = IntPrompt.ask(t("quant_select_model"), default=1, show_choices=False)
    if not 1 <= picked <= len(shown):
        console.print(f"[red]{t('quant_invalid_choice')}[/red]")
        return None

    # Displayed numbers are 1-based
    return shown[picked - 1]
def configure_quantization_method() -> Dict[str, str]:
    """Ask for the quantization method and the input weight type.

    Returns:
        Dict with "method" ("int4" or "int8") and "input_type"
        ("fp8", "fp16" or "bf16").
    """
    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_step2_method')}[/bold cyan]", expand=False))
    console.print()

    # Method selection
    console.print(f"[bold]{t('quant_method_label')}[/bold]")
    for number, description_key in (("1", "quant_int4_desc"), ("2", "quant_int8_desc")):
        console.print(f" [cyan][{number}][/cyan] {t(description_key)}")
    console.print()

    method_answer = Prompt.ask(t("quant_select_method"), choices=["1", "2"], default="1")
    if method_answer == "1":
        method = "int4"
    else:
        method = "int8"

    # Input type selection
    console.print()
    console.print(f"[bold]{t('quant_input_type_label')}[/bold]")
    for number, description_key in (("1", "quant_fp8_desc"), ("2", "quant_fp16_desc"), ("3", "quant_bf16_desc")):
        console.print(f" [cyan][{number}][/cyan] {t(description_key)}")
    console.print()

    input_answer = Prompt.ask(t("quant_select_input_type"), choices=["1", "2", "3"], default="1")
    input_types = {"1": "fp8", "2": "fp16", "3": "bf16"}

    return {"method": method, "input_type": input_types[input_answer]}
def configure_cpu_params(max_cores: int, max_numa: int) -> Dict[str, Any]:
    """Interactively collect the CPU thread count, NUMA node count and GPU usage.

    Out-of-range answers are replaced by the offered default rather than
    re-prompting.

    Args:
        max_cores: Largest CPU thread count accepted
        max_numa: Largest NUMA node count accepted

    Returns:
        Dict with keys "cpu_threads", "numa_nodes" and "use_gpu".
    """
    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_step3_cpu')}[/bold cyan]", expand=False))
    console.print()

    def in_range_or_default(value: int, min_val: int, max_val: int, default: int) -> int:
        """Return value when min_val <= value <= max_val, otherwise default."""
        return value if min_val <= value <= max_val else default

    # int(max_cores * 0.8) is 0 on a single-core host, which made the offered and
    # fallback thread count 0; never default below 1.
    default_threads = max(1, int(max_cores * 0.8))
    cpu_threads = IntPrompt.ask(t("quant_cpu_threads_prompt", max=max_cores), default=default_threads)
    cpu_threads = in_range_or_default(cpu_threads, 1, max_cores, default_threads)

    # Same floor for the NUMA default so the fallback is always a usable value.
    default_numa = max(1, max_numa)
    numa_nodes = IntPrompt.ask(t("quant_numa_nodes_prompt", max=max_numa), default=default_numa)
    numa_nodes = in_range_or_default(numa_nodes, 1, max_numa, default_numa)

    # Ask about GPU usage
    console.print()
    console.print(f"[bold]{t('quant_use_gpu_label')}[/bold]")
    console.print(f" [dim]{t('quant_gpu_speedup')}[/dim]")
    console.print()
    use_gpu = Confirm.ask(t("quant_enable_gpu"), default=True)

    return {"cpu_threads": cpu_threads, "numa_nodes": numa_nodes, "use_gpu": use_gpu}
def configure_output_path(model: Any, method: str, numa_nodes: int) -> Path:
    """Prompt for the directory that will receive the quantized weights.

    The suggested default is ``<base>/<model dir name>-AMX<METHOD>-NUMA<n>``
    where ``<base>`` is the first of these that applies:

    1. the configured weights directory, if it exists;
    2. the first configured model storage path, if it exists;
    3. the parent directory of the source model.

    Args:
        model: Object exposing a ``path`` attribute (source model directory).
        method: Quantization method name; upper-cased in the directory name.
        numa_nodes: NUMA node count embedded in the directory name.

    Returns:
        The default path if the user accepts it, otherwise the path they typed
        (with a leading ``~`` expanded to the home directory).
    """
    from kt_kernel.cli.config.settings import get_settings

    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_step4_output')}[/bold cyan]", expand=False))
    console.print()

    model_path = Path(model.path)
    output_dir_name = f"{model_path.name}-AMX{method.upper()}-NUMA{numa_nodes}"
    settings = get_settings()

    # Priority: paths.weights > paths.models[0] > model's parent directory
    weights_dir = settings.weights_dir
    if weights_dir and weights_dir.exists():
        base_dir = weights_dir
    else:
        model_paths = settings.get_model_paths()
        if model_paths and model_paths[0].exists():
            base_dir = model_paths[0]
        else:
            base_dir = model_path.parent
    default_output = base_dir / output_dir_name

    console.print(f"[dim]{t('quant_default_path')}[/dim]", default_output)
    console.print()

    if Confirm.ask(t("quant_use_default"), default=True):
        return default_output

    custom_path = Prompt.ask(t("quant_custom_path"), default=str(default_output))
    # Path() does not interpret "~"; without expanduser() a typed "~/weights"
    # would create a literal "~" directory under the current working directory.
    return Path(custom_path).expanduser()
def calculate_quantized_size(source_path: Path, input_type: str, quant_method: str) -> tuple[float, float]:
    """
    Measure the source weights and estimate the size after quantization.

    Only ``*.safetensors`` files directly inside ``source_path`` are counted.
    The estimate scales the source size by ``quant_bits / input_bits``.

    Args:
        source_path: Directory holding the source model
        input_type: Source weight type (fp8, fp16, bf16); unknown values count as 16-bit
        quant_method: Target method (int4, int8); unknown values count as 4-bit

    Returns:
        Tuple of (source_size_gb, estimated_quant_size_gb), or (0.0, 0.0) if
        the files could not be inspected
    """
    try:
        byte_count = 0
        for shard in source_path.glob("*.safetensors"):
            if shard.is_file():
                byte_count += shard.stat().st_size
        source_size_gb = byte_count / (1024**3)
    except Exception:
        return 0.0, 0.0

    # Bit widths of the source and target representations
    bits_in = {"fp8": 8, "fp16": 16, "bf16": 16}.get(input_type, 16)
    bits_out = {"int4": 4, "int8": 8}.get(quant_method, 4)

    return source_size_gb, source_size_gb * (bits_out / bits_in)
def check_disk_space(output_path: Path, required_size_gb: float) -> tuple[float, bool]:
    """
    Report free disk space for the location where output will be written.

    If ``output_path`` does not exist yet, the nearest existing ancestor
    directory is measured instead.

    Args:
        output_path: Target output path (may not exist yet)
        required_size_gb: Space the output is expected to need, in GB

    Returns:
        Tuple of (available_gb, is_sufficient); is_sufficient is True when
        available >= required * 1.2. Returns (0.0, False) if the query fails.
    """
    import shutil

    try:
        # Walk up until something that exists is found (or the root is hit).
        probe = output_path if output_path.exists() else output_path.parent
        while not probe.exists() and probe != probe.parent:
            probe = probe.parent

        free_gb = shutil.disk_usage(probe).free / (1024**3)
        # Require a 20% safety margin on top of the estimate.
        return free_gb, free_gb >= (required_size_gb * 1.2)
    except Exception:
        return 0.0, False
def interactive_quant_config() -> Optional[Dict[str, Any]]:
    """
    Run the interactive ``kt quant`` wizard.

    Steps: pick a model, optionally verify it, choose method and input type,
    choose CPU/NUMA/GPU settings, choose an output path (made unique if it
    already exists), show a disk-space analysis, then ask for confirmation.

    Returns:
        Dict with keys ``model``, ``method``, ``input_type``, ``cpu_threads``,
        ``numa_nodes``, ``use_gpu`` and ``output_path``; or None if no model
        was selected or the user cancelled.
    """
    from kt_kernel.cli.utils.environment import detect_cpu_info

    cpu_info = detect_cpu_info()

    # Step 1: model selection
    model = select_model_to_quantize()
    if not model:
        return None

    # Step 1.5: optional verification, only for registered safetensors models
    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
    from kt_kernel.cli.utils.model_verifier import pre_operation_verification

    registry = UserModelRegistry()
    registered = registry.find_by_path(model.path)
    if registered and registered.format == "safetensors":
        pre_operation_verification(registered, registry, operation_name="quantizing")

    # Step 2: quantization method / input type
    quant_config = configure_quantization_method()

    # Step 3: CPU parameters (bounded by logical threads and NUMA nodes)
    cpu_config = configure_cpu_params(cpu_info.threads, cpu_info.numa_nodes)

    # Step 4: output location
    output_path = configure_output_path(model, quant_config["method"], cpu_config["numa_nodes"])

    # Step 4.5: never overwrite - append "-2", "-3", ... until the name is free
    if output_path.exists():
        console.print()
        console.print(t("quant_output_exists_warn", path=str(output_path)))
        console.print()

        base_path = output_path
        suffix = 2
        while output_path.exists():
            output_path = base_path.parent / f"{base_path.name}-{suffix}"
            suffix += 1

        console.print(t("quant_using_unique_name", path=str(output_path)))
        console.print()

    # Step 5: disk space analysis
    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_disk_analysis')}[/bold cyan]", expand=False))
    console.print()

    source_size_gb, estimated_size_gb = calculate_quantized_size(
        Path(model.path), quant_config["input_type"], quant_config["method"]
    )
    available_gb, is_sufficient = check_disk_space(output_path, estimated_size_gb)
    space_style = "green" if is_sufficient else "red"

    console.print(f" {t('quant_source_size'):<26} [cyan]{source_size_gb:.2f} GB[/cyan]")
    console.print(f" {t('quant_estimated_size'):<26} [yellow]{estimated_size_gb:.2f} GB[/yellow]")
    console.print(f" {t('quant_available_space'):<26} [{space_style}]{available_gb:.2f} GB[/{space_style}]")
    console.print()

    if not is_sufficient:
        # Same 20% margin that check_disk_space applied.
        required_with_buffer = estimated_size_gb * 1.2
        shortage_gb = required_with_buffer - available_gb
        console.print(f"[bold red]⚠ {t('quant_insufficient_space')}[/bold red]")
        console.print()
        console.print(f" {t('quant_required_space'):<26} [yellow]{required_with_buffer:.2f} GB[/yellow]")
        console.print(f" {t('quant_available_space'):<26} [red]{available_gb:.2f} GB[/red]")
        console.print(f" {t('quant_shortage'):<26} [red]{shortage_gb:.2f} GB[/red]")
        console.print()
        console.print(f" {t('quant_may_fail')}")
        console.print()

        proceed = Confirm.ask(f"[yellow]{t('quant_continue_anyway')}[/yellow]", default=False)
        if not proceed:
            console.print(f"[yellow]{t('quant_cancelled')}[/yellow]")
            return None
        console.print()

    # Summary and final confirmation
    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_config_summary')}[/bold cyan]", expand=False))
    console.print()
    console.print(f" {t('quant_summary_model'):<15} {model.name}")
    console.print(f" {t('quant_summary_method'):<15} {quant_config['method'].upper()}")
    console.print(f" {t('quant_summary_input_type'):<15} {quant_config['input_type'].upper()}")
    console.print(f" {t('quant_summary_cpu_threads'):<15} {cpu_config['cpu_threads']}")
    console.print(f" {t('quant_summary_numa'):<15} {cpu_config['numa_nodes']}")
    console.print(f" {t('quant_summary_gpu'):<15} {t('yes') if cpu_config['use_gpu'] else t('no')}")
    console.print(f" {t('quant_summary_output'):<15} {output_path}")
    console.print()

    confirmed = Confirm.ask(f"[bold green]{t('quant_start_question')}[/bold green]", default=True)
    if not confirmed:
        console.print(f"[yellow]{t('quant_cancelled')}[/yellow]")
        return None

    return {
        "model": model,
        "method": quant_config["method"],
        "input_type": quant_config["input_type"],
        "cpu_threads": cpu_config["cpu_threads"],
        "numa_nodes": cpu_config["numa_nodes"],
        "use_gpu": cpu_config["use_gpu"],
        "output_path": output_path,
    }

View File

@@ -0,0 +1,364 @@
"""
Repo Detector
Automatically detect repository information from model README.md files
"""
import re
from pathlib import Path
from typing import Optional, Dict, Tuple
import yaml
def parse_readme_frontmatter(readme_path: Path) -> Optional[Dict]:
    """
    Parse the YAML frontmatter at the top of a README.md

    The frontmatter is the text between a leading ``---`` line and the next
    ``---`` line. A UTF-8 byte-order mark before the first marker is ignored,
    and the closing marker may be the last line of the file.

    Args:
        readme_path: Path to README.md file

    Returns:
        Dictionary of frontmatter data, or None if the file is missing or
        unreadable, has no frontmatter, or the frontmatter is not a YAML mapping
    """
    if not readme_path.exists():
        return None

    try:
        # "utf-8-sig" strips a leading BOM, which would otherwise stop the
        # "^---" anchor below from matching.
        content = readme_path.read_text(encoding="utf-8-sig")
    except (OSError, UnicodeError):
        return None

    # Closing marker is followed by a newline OR the end of the file.
    match = re.match(r"^---\s*\n(.*?)\n---\s*(?:\n|$)", content, re.DOTALL)
    if not match:
        return None

    try:
        data = yaml.safe_load(match.group(1))
    except Exception:
        # Detection is best-effort: malformed metadata (YAML syntax errors,
        # invalid scalar values, ...) just means "no frontmatter".
        return None

    return data if isinstance(data, dict) else None
def extract_repo_from_frontmatter(frontmatter: Dict) -> Optional[Tuple[str, str]]:
    """
    Derive (repo_id, repo_type) from parsed README frontmatter

    Lookup order:
      1. ``license_link`` - parsed as a repository URL
      2. ``base_model`` - a string, or the first item of a list
      3. ``model-index`` - the ``name`` of the first entry
      4. ``model_name``

    For 2-4 the value must be a string of the form "namespace/model-name".
    The type is "modelscope" when the id contains "modelscope" or the ``tags``
    list has a "modelscope" entry (case-insensitive); otherwise "huggingface".

    Args:
        frontmatter: Parsed YAML frontmatter dictionary

    Returns:
        Tuple of (repo_id, repo_type) or None
    """
    if not frontmatter:
        return None

    # 1. license_link carries a full URL, so it also tells us the hosting site.
    license_link = frontmatter.get("license_link")
    if isinstance(license_link, str) and license_link:
        from_link = _extract_repo_from_url(license_link)
        if from_link:
            return from_link

    # 2. base_model
    candidate = None
    base_model = frontmatter.get("base_model")
    if isinstance(base_model, list):
        if base_model:
            candidate = base_model[0]
    elif isinstance(base_model, str) and base_model:
        candidate = base_model

    # 3. model-index
    if not candidate:
        model_index = frontmatter.get("model-index")
        if isinstance(model_index, list) and model_index:
            head = model_index[0]
            if isinstance(head, dict):
                candidate = head.get("name")

    # 4. model_name
    if not candidate:
        candidate = frontmatter.get("model_name")

    if not (candidate and isinstance(candidate, str)):
        return None

    # Exactly one "/" separating namespace and model name.
    if len(candidate.split("/")) != 2:
        return None

    tags = frontmatter.get("tags", [])
    tagged_modelscope = isinstance(tags, list) and "modelscope" in [str(tag).lower() for tag in tags]
    if tagged_modelscope or "modelscope" in candidate.lower():
        return (candidate, "modelscope")
    return (candidate, "huggingface")
def _extract_repo_from_url(url: str) -> Optional[Tuple[str, str]]:
"""
Extract repo_id and repo_type from a URL
Supports:
- https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/LICENSE
- https://modelscope.cn/models/Qwen/Qwen3-30B-A3B
Args:
url: URL string
Returns:
Tuple of (repo_id, repo_type) or None
"""
# HuggingFace pattern: https://huggingface.co/{namespace}/{model}/...
hf_match = re.match(r"https?://huggingface\.co/([^/]+)/([^/]+)", url)
if hf_match:
namespace = hf_match.group(1)
model_name = hf_match.group(2)
repo_id = f"{namespace}/{model_name}"
return (repo_id, "huggingface")
# ModelScope pattern: https://modelscope.cn/models/{namespace}/{model}
ms_match = re.match(r"https?://(?:www\.)?modelscope\.cn/models/([^/]+)/([^/]+)", url)
if ms_match:
namespace = ms_match.group(1)
model_name = ms_match.group(2)
repo_id = f"{namespace}/{model_name}"
return (repo_id, "modelscope")
return None
def extract_repo_from_global_search(readme_path: Path) -> Optional[Tuple[str, str]]:
    """
    Extract repo info by globally searching for URLs in README.md

    All HuggingFace URLs are collected first, then all ModelScope URLs, and
    the last entry of that combined list is returned. So a ModelScope URL
    anywhere in the file wins over any HuggingFace URL.

    Repo path segments are limited to letters, digits, "_", "-" and ".", and a
    trailing "." is dropped, so surrounding markdown or sentence punctuation
    ("...DeepSeek-V3.", "...V3,", "...V3)") does not leak into the repo_id.

    Args:
        readme_path: Path to README.md file

    Returns:
        Tuple of (repo_id, repo_type) or None if not found
    """
    if not readme_path.exists():
        return None

    try:
        content = readme_path.read_text(encoding="utf-8")
    except (OSError, UnicodeError):
        return None

    segment = r"([\w.\-]+)"
    hf_matches = re.findall(rf"https?://huggingface\.co/{segment}/{segment}", content)
    ms_matches = re.findall(rf"https?://(?:www\.)?modelscope\.cn/models/{segment}/{segment}", content)

    found_repos = []

    for namespace, model_name in hf_matches:
        # Skip site sections and repository sub-pages that are not repos.
        if namespace.lower() in ("docs", "blog", "spaces", "datasets"):
            continue
        if model_name.lower() in ("tree", "blob", "raw", "resolve", "discussions"):
            continue
        # A URL that ends a sentence captures the full stop as well.
        model_name = model_name.rstrip(".")
        if model_name:
            found_repos.append((f"{namespace}/{model_name}", "huggingface"))

    for namespace, model_name in ms_matches:
        model_name = model_name.rstrip(".")
        if model_name:
            found_repos.append((f"{namespace}/{model_name}", "modelscope"))

    # If several repos were found, the last one wins.
    return found_repos[-1] if found_repos else None
def detect_repo_for_model(model_path: str) -> Optional[Tuple[str, str]]:
    """
    Detect which repository a local model directory came from

    Only the YAML frontmatter of ``README.md`` inside the directory is
    consulted; there is deliberately no fallback to scanning the README body
    for URLs (too many false positives).

    Args:
        model_path: Path to model directory

    Returns:
        Tuple of (repo_id, repo_type) or None if not detected
    """
    model_dir = Path(model_path)
    if not (model_dir.exists() and model_dir.is_dir()):
        return None

    readme_path = model_dir / "README.md"
    if not readme_path.exists():
        return None

    frontmatter = parse_readme_frontmatter(readme_path)
    return extract_repo_from_frontmatter(frontmatter) if frontmatter else None
def scan_models_for_repo(model_list) -> Dict:
    """
    Run repository detection over a list of models

    A model is skipped when it already has a ``repo_id`` or when its
    ``format`` is neither "safetensors" nor "gguf".

    Args:
        model_list: List of UserModel objects

    Returns:
        Dictionary with scan results:
        {
            'detected': [(model, repo_id, repo_type), ...],
            'not_detected': [model, ...],
            'skipped': [model, ...]
        }
    """
    detected, not_detected, skipped = [], [], []

    for model in model_list:
        if model.repo_id or model.format not in ["safetensors", "gguf"]:
            skipped.append(model)
            continue

        repo_info = detect_repo_for_model(model.path)
        if not repo_info:
            not_detected.append(model)
        else:
            repo_id, repo_type = repo_info
            detected.append((model, repo_id, repo_type))

    return {"detected": detected, "not_detected": not_detected, "skipped": skipped}
def format_detection_report(results: Dict) -> str:
    """
    Render scan results as a plain-text report

    Detected and not-detected models are listed individually; skipped models
    are only counted. A one-line summary closes the report.

    Args:
        results: Results from scan_models_for_repo()

    Returns:
        Formatted string report
    """
    rule = "=" * 80
    detected = results["detected"]
    not_detected = results["not_detected"]
    skipped = results["skipped"]

    report = [rule, "Auto-Detection Report", rule, ""]

    if detected:
        report += [f"✓ Detected repository information ({len(detected)} models):", ""]
        for model, repo_id, repo_type in detected:
            report += [
                f"{model.name}",
                f" Path: {model.path}",
                f" Repo: {repo_id} ({repo_type})",
                "",
            ]

    if not_detected:
        report += [f"✗ No repository information found ({len(not_detected)} models):", ""]
        for model in not_detected:
            report += [f"{model.name}", f" Path: {model.path}", ""]

    if skipped:
        report += [
            f"⊘ Skipped ({len(skipped)} models):",
            " (Already have repo_id or not safetensors/gguf format)",
            "",
        ]

    summary = f"Summary: {len(detected)} detected, {len(not_detected)} not detected, {len(skipped)} skipped"
    report += [rule, summary, rule]

    return "\n".join(report)
def apply_detection_results(results: Dict, registry) -> int:
    """
    Write detected repo information back to the registry

    For every entry in ``results["detected"]`` the registry's
    ``update_model`` is called with the model name and the new ``repo_id`` /
    ``repo_type``; an update counts only when that call returns a truthy value.

    Args:
        results: Results from scan_models_for_repo()
        registry: UserModelRegistry instance

    Returns:
        Number of models updated
    """
    return sum(
        1
        for model, repo_id, repo_type in results["detected"]
        if registry.update_model(model.name, {"repo_id": repo_id, "repo_type": repo_type})
    )

View File

@@ -0,0 +1,111 @@
"""
Configuration save/load for kt run command.
Manages saved run configurations bound to specific models.
"""
from pathlib import Path
from typing import Dict, List, Optional, Any
from datetime import datetime
import yaml
CONFIG_FILE = Path.home() / ".ktransformers" / "run_configs.yaml"
class RunConfigManager:
    """Manager for saved run configurations.

    Configurations live in a single YAML file shaped like
    ``{"version": "1.0", "configs": {model_id: [config, ...]}}``.
    """

    def __init__(self):
        self.config_file = CONFIG_FILE
        self._ensure_config_file()

    @staticmethod
    def _empty_data() -> Dict:
        """Return a fresh, empty config document."""
        return {"version": "1.0", "configs": {}}

    def _ensure_config_file(self):
        """Create the config file (and its directory) if it does not exist."""
        if not self.config_file.exists():
            self.config_file.parent.mkdir(parents=True, exist_ok=True)
            self._save_data(self._empty_data())

    def _load_data(self) -> Dict:
        """Load raw config data.

        Returns:
            The parsed document with ``"configs"`` guaranteed to be a dict.
            An unreadable, empty or non-mapping file yields an empty document.
        """
        try:
            with open(self.config_file, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
        except Exception:
            return self._empty_data()

        # A hand-edited file may parse to a list/scalar, or have "configs: null".
        if not isinstance(data, dict):
            return self._empty_data()
        if not isinstance(data.get("configs"), dict):
            data["configs"] = {}
        return data

    def _save_data(self, data: Dict):
        """Save raw config data.

        Serialization happens before the file is opened, so a value YAML
        cannot represent raises without truncating the existing file.
        ``safe_dump`` is used because the file is read back with ``safe_load``:
        python-specific tags written by ``yaml.dump`` would make the whole
        file unloadable, and every saved config would then be silently lost.
        """
        text = yaml.safe_dump(data, allow_unicode=True, default_flow_style=False)
        with open(self.config_file, "w", encoding="utf-8") as f:
            f.write(text)

    def list_configs(self, model_id: str) -> List[Dict[str, Any]]:
        """List all saved configs for a model.

        Returns:
            List of config dicts with 'config_name' and other fields;
            empty if the model has none.
        """
        configs = self._load_data()["configs"].get(model_id, [])
        return configs if isinstance(configs, list) else []

    def save_config(self, model_id: str, config: Dict[str, Any]):
        """Save a configuration for a model.

        A ``created_at`` ISO timestamp is added to the stored copy; the
        caller's ``config`` dict is left untouched.

        Args:
            model_id: Model ID to bind config to
            config: Configuration dict with all run parameters
        """
        data = self._load_data()
        per_model = data["configs"]

        entries = per_model.get(model_id)
        if not isinstance(entries, list):
            entries = per_model[model_id] = []

        # Shallow copy so that stamping the entry does not mutate the argument.
        entry = dict(config)
        entry["created_at"] = datetime.now().isoformat()
        entries.append(entry)

        self._save_data(data)

    def delete_config(self, model_id: str, config_index: int) -> bool:
        """Delete a saved configuration.

        Args:
            model_id: Model ID
            config_index: Index of config to delete (0-based)

        Returns:
            True if deleted, False if not found
        """
        data = self._load_data()
        configs = data["configs"].get(model_id)
        if not isinstance(configs, list):
            return False
        if not 0 <= config_index < len(configs):
            return False

        del configs[config_index]
        self._save_data(data)
        return True

    def get_config(self, model_id: str, config_index: int) -> Optional[Dict[str, Any]]:
        """Get a specific saved configuration.

        Args:
            model_id: Model ID
            config_index: Index of config to get (0-based)

        Returns:
            Config dict or None if not found
        """
        configs = self.list_configs(model_id)
        if not 0 <= config_index < len(configs):
            return None
        return configs[config_index]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,459 @@
"""
Tuna engine for auto-tuning GPU experts configuration.
Automatically finds the maximum viable num-gpu-experts through binary search
by testing actual server launches with different configurations.
"""
import json
import math
import random
import subprocess
import sys
import time
from pathlib import Path
from typing import Optional
from kt_kernel.cli.utils.console import console, print_error, print_info, print_warning
def get_num_experts(model_path: Path) -> int:
    """
    Get the number of experts per layer from model config.

    Keys that hold the total expert count are tried first.
    ``num_experts_per_tok`` is the number of experts *activated per token*,
    not the number of experts; it is kept only as a last resort for configs
    that define nothing else.

    Args:
        model_path: Path to the model directory

    Returns:
        Number of experts per layer

    Raises:
        ValueError: If config.json is missing, cannot be parsed, is not a
            JSON object, or contains none of the known keys with an integer
            value
    """
    config_file = model_path / "config.json"
    if not config_file.exists():
        raise ValueError(f"config.json not found in {model_path}")

    try:
        config = json.loads(config_file.read_text(encoding="utf-8"))
    except Exception as e:
        raise ValueError(f"Failed to parse config.json: {e}") from e

    if not isinstance(config, dict):
        raise ValueError(f"Failed to parse config.json: expected a JSON object in {config_file}")

    # Different models may use different field names
    possible_keys = [
        "n_routed_experts",  # DeepSeek
        "num_local_experts",  # Mixtral
        "num_experts",  # Qwen MoE / generic
        "num_experts_per_tok",  # active experts per token - last resort only
    ]

    for key in possible_keys:
        value = config.get(key)
        # bool is an int subclass; "true" is not an expert count.
        if isinstance(value, int) and not isinstance(value, bool):
            return value

    raise ValueError(f"Cannot find num_experts field in {config_file}. " f"Tried: {', '.join(possible_keys)}")
def detect_oom(log_line: Optional[str]) -> bool:
    """
    Detect OOM (Out Of Memory) errors from log output.

    Matching is case-insensitive. The abbreviation "oom" only counts when it
    is not embedded in a longer alphanumeric word, so ordinary words and model
    names such as "bloom", "room" or "zoom" do not trigger a false positive,
    while "OOM", "oom-killer" and "oom_kill" still match.

    Args:
        log_line: A line from server output

    Returns:
        True if OOM detected, False otherwise
    """
    import re

    if log_line is None:
        return False

    log_lower = log_line.lower()

    # "out of memory" also covers "cuda out of memory".
    oom_phrases = (
        "out of memory",
        "outofmemoryerror",
        "failed to allocate",
        "cumemalloc failed",
        "cumemallocasync failed",
        "allocation failed",
    )
    if any(phrase in log_lower for phrase in oom_phrases):
        return True

    # Stand-alone "oom": no letter or digit directly before or after it.
    return re.search(r"(?<![a-z0-9])oom(?![a-z0-9])", log_lower) is not None
def _pump_lines(stream, sink) -> None:
    """Forward every line read from *stream* to *sink*, then put ``None`` as an end marker."""
    try:
        for line in stream:
            sink.put(line)
    except Exception:
        # The pipe can break while the child is being torn down; either way
        # there is nothing more to read.
        pass
    finally:
        sink.put(None)
        try:
            # Closed here, by the only thread that reads it, so the close can
            # never contend with a read in progress.
            stream.close()
        except Exception:
            pass


def _stop_process(process, grace: float = 2) -> None:
    """Terminate *process*, escalate to kill after *grace* seconds, and reap it.

    Best-effort: errors during cleanup are swallowed so they cannot replace
    the result of the test that is being torn down.
    """
    try:
        if process.poll() is not None:
            return
        process.terminate()
        try:
            process.wait(timeout=grace)
        except subprocess.TimeoutExpired:
            process.kill()
            # Reap after kill so no zombie process is left behind.
            process.wait(timeout=grace)
    except Exception:
        try:
            process.kill()
        except Exception:
            pass


def test_config(
    num_gpu_experts: int,
    model_path: Path,
    config: dict,
    verbose: bool = False,
) -> tuple[bool, float]:
    """
    Test if a configuration with given num_gpu_experts works.

    Launches ``sglang.launch_server`` on a random local port, watches its
    output for an OOM message or a startup banner, and, once the server is up,
    sends one inference request via test_inference(). The server is always
    stopped before returning.

    Server output is read by a background thread, so the startup timeout is
    enforced even when the server prints nothing, and the pipe keeps being
    drained while the inference request is in flight.

    Args:
        num_gpu_experts: Number of GPU experts to test
        model_path: Path to the model
        config: Configuration dict. ``tensor_parallel_size`` and
            ``max_total_tokens`` are required; every other key is optional and
            falls back to the default visible in the command built below.
            ``env`` (if set) becomes the child's environment and
            ``startup_timeout`` (seconds, default 60) bounds the wait for the
            startup banner.
        verbose: Whether to show detailed logs

    Returns:
        (success: bool, elapsed_time: float)
        - success: True if server starts and inference works
        - elapsed_time: Time taken for the test
    """
    import queue
    import threading

    start_time = time.time()

    # Use random port to avoid conflicts
    test_port = random.randint(30000, 40000)

    # Build command
    cmd = [
        sys.executable,
        "-m",
        "sglang.launch_server",
        "--model",
        str(model_path),
        "--port",
        str(test_port),
        "--host",
        "127.0.0.1",
        "--tensor-parallel-size",
        str(config["tensor_parallel_size"]),
        "--kt-num-gpu-experts",
        str(num_gpu_experts),
        "--max-total-tokens",
        str(config["max_total_tokens"]),
    ]

    # Add kt-kernel options
    if config.get("weights_path"):
        cmd.extend(["--kt-weight-path", str(config["weights_path"])])
    else:
        cmd.extend(["--kt-weight-path", str(model_path)])

    cmd.extend(
        [
            "--kt-cpuinfer",
            str(config.get("cpu_threads", 64)),
            "--kt-threadpool-count",
            str(config.get("numa_nodes", 2)),
            "--kt-method",
            config.get("kt_method", "AMXINT4"),
            "--kt-gpu-prefill-token-threshold",
            str(config.get("kt_gpu_prefill_threshold", 4096)),
        ]
    )

    # Add other SGLang options
    if config.get("attention_backend"):
        cmd.extend(["--attention-backend", config["attention_backend"]])

    cmd.extend(
        [
            "--trust-remote-code",
            "--mem-fraction-static",
            str(config.get("mem_fraction_static", 0.98)),
            "--chunked-prefill-size",
            str(config.get("chunked_prefill_size", 4096)),
            "--max-running-requests",
            str(config.get("max_running_requests", 1)),  # Use 1 for faster testing
            "--watchdog-timeout",
            str(config.get("watchdog_timeout", 3000)),
            "--enable-mixed-chunk",
            "--enable-p2p-check",
        ]
    )

    # Add disable-shared-experts-fusion if specified
    if config.get("disable_shared_experts_fusion"):
        cmd.append("--disable-shared-experts-fusion")

    # Add extra args
    if config.get("extra_args"):
        cmd.extend(config["extra_args"])

    if verbose:
        console.print(f"[dim]Command: {' '.join(cmd)}[/dim]")

    # Start process
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            env=config.get("env"),
        )
    except Exception as e:
        if verbose:
            print_error(f"Failed to start process: {e}")
        return False, time.time() - start_time

    # readline() on the pipe blocks until a full line arrives, so reading is
    # delegated to a daemon thread and this thread polls a queue instead.
    output_lines = queue.Queue()
    threading.Thread(target=_pump_lines, args=(process.stdout, output_lines), daemon=True).start()

    timeout = config.get("startup_timeout", 60)  # Seconds to wait for the startup banner
    server_ready = False

    try:
        while time.time() - start_time < timeout:
            try:
                line = output_lines.get(timeout=0.1)
            except queue.Empty:
                # No output right now: only then check whether the child died.
                if process.poll() is not None:
                    if verbose:
                        print_warning("Process exited early")
                    return False, time.time() - start_time
                continue

            if line is None:
                # End of output before the startup banner was seen.
                if verbose:
                    print_warning("Process exited early")
                break

            if verbose:
                console.print(f"[dim]{line.rstrip()}[/dim]")

            # Fast OOM detection
            if detect_oom(line):
                if verbose:
                    print_warning(f"OOM detected: {line.rstrip()}")
                _stop_process(process)
                return False, time.time() - start_time

            # Check for startup success
            if "Uvicorn running" in line or "Application startup complete" in line:
                server_ready = True
                break

        if not server_ready:
            # Timeout or failed to start
            _stop_process(process)
            return False, time.time() - start_time

        # Server is ready, test inference
        success = test_inference(test_port, verbose=verbose)

        # Cleanup
        _stop_process(process, grace=5)
        return success, time.time() - start_time

    except KeyboardInterrupt:
        # User cancelled
        _stop_process(process)
        raise
    except Exception as e:
        if verbose:
            print_error(f"Test failed with exception: {e}")
        _stop_process(process)
        return False, time.time() - start_time
def test_inference(port: int, verbose: bool = False) -> bool:
    """
    Test if the server can handle a simple inference request.

    Sends a one-token chat completion to ``http://127.0.0.1:<port>/v1`` using
    the ``openai`` client. If that package is not installed the check is
    skipped and reported as a success.

    Args:
        port: Server port
        verbose: Whether to show detailed logs

    Returns:
        True if inference succeeds (or could not be tested), False otherwise
    """
    try:
        # Wait a bit for server to be fully ready
        time.sleep(2)

        # Try to import OpenAI client
        try:
            from openai import OpenAI
        except ImportError:
            if verbose:
                print_warning("OpenAI package not available, skipping inference test")
            return True  # Assume success if we can't test

        client = OpenAI(
            base_url=f"http://127.0.0.1:{port}/v1",
            api_key="test",
        )

        # Send a simple test request
        response = client.chat.completions.create(
            model="test",
            messages=[{"role": "user", "content": "Hi"}],
            max_tokens=1,
            temperature=0,
            timeout=10,
        )

        # Check if we got a valid response. bool() matters: a bare "and" chain
        # would hand back the empty choices list (or None) instead of False.
        choices = response.choices
        success = bool(choices) and choices[0].message.content is not None

        if verbose:
            if success:
                print_info(f"Inference test passed: {response.choices[0].message.content}")
            else:
                print_warning("Inference test failed: no valid response")

        return success

    except Exception as e:
        if verbose:
            print_warning(f"Inference test failed: {e}")
        return False
def find_max_gpu_experts(
    model_path: Path,
    config: dict,
    verbose: bool = False,
) -> int:
    """
    Binary search to find the maximum viable num_gpu_experts.

    Searches the range [0, num_experts], launching one test server per step
    via test_config(). The result is 0 both when only gpu-experts=0 works and
    when nothing works at all; callers must tell those cases apart themselves.

    Args:
        model_path: Path to the model
        config: Configuration dict
        verbose: Whether to show detailed logs

    Returns:
        Maximum number of GPU experts that works

    Raises:
        ValueError: If the expert count cannot be read from the model config
    """
    # Get number of experts from model config
    try:
        num_experts = get_num_experts(model_path)
    except ValueError as e:
        print_error(str(e))
        raise

    console.print()
    console.print(f"Binary search range: [0, {num_experts}]")
    console.print()

    left, right = 0, num_experts
    result = 0
    iteration = 0
    # Worst-case step count over the num_experts + 1 candidates is
    # floor(log2(n)) + 1, i.e. n.bit_length(). ceil(log2(n)) is one short
    # whenever n is a power of two (e.g. 7 experts -> 4 steps, not 3).
    total_iterations = (int(num_experts) + 1).bit_length()

    while left <= right:
        iteration += 1
        mid = (left + right) // 2

        console.print(f"[{iteration}/{total_iterations}] Testing gpu-experts={mid}... ", end="")

        success, elapsed = test_config(mid, model_path, config, verbose=verbose)

        if success:
            console.print(f"[green]✓ OK[/green] ({elapsed:.1f}s)")
            result = mid
            left = mid + 1
        else:
            console.print(f"[red]✗ FAILED[/red] ({elapsed:.1f}s)")
            right = mid - 1

    return result
def run_tuna(
    model_path: Path,
    tensor_parallel_size: int,
    max_total_tokens: int,
    kt_method: str,
    verbose: bool = False,
    **kwargs,
) -> int:
    """
    Auto-tune num_gpu_experts for a model.

    Runs the binary search and, when it reports 0, launches one more test with
    gpu-experts=0 to find out whether the server can start at all.

    Args:
        model_path: Path to the model
        tensor_parallel_size: Tensor parallel size
        max_total_tokens: Maximum total tokens
        kt_method: KT quantization method
        verbose: Whether to show detailed logs
        **kwargs: Additional configuration parameters, merged into the config
            dict handed to the test launcher

    Returns:
        Optimal num_gpu_experts value

    Raises:
        ValueError: If the server cannot start even with gpu-experts=0
    """
    config = {
        "tensor_parallel_size": tensor_parallel_size,
        "max_total_tokens": max_total_tokens,
        "kt_method": kt_method,
    }
    config.update(kwargs)

    try:
        result = find_max_gpu_experts(model_path, config, verbose=verbose)
    except KeyboardInterrupt:
        console.print()
        print_warning("Tuning cancelled by user")
        raise

    console.print()

    if result != 0:
        return result

    # A result of 0 is ambiguous: confirm that gpu-experts=0 really works.
    console.print("[yellow]Testing if gpu-experts=0 is viable...[/yellow]")
    zero_ok, _ = test_config(0, model_path, config, verbose=verbose)

    if zero_ok:
        # 0 works but nothing more
        console.print()
        print_warning("All experts will run on CPU (gpu-experts=0). Performance will be limited by CPU speed.")
        return result

    # Even 0 doesn't work: explain and abort.
    console.print()
    print_error("Failed to start server even with all experts on CPU (gpu-experts=0)")
    console.print()
    console.print("[bold]Possible reasons:[/bold]")
    for reason in (
        " • Insufficient GPU memory for base model layers",
        " • max-total-tokens is too large for available VRAM",
        " • Tensor parallel configuration issue",
    ):
        console.print(reason)
    console.print()
    console.print("[bold]Suggestions:[/bold]")
    for suggestion in (
        f" • Reduce --max-total-tokens (current: {max_total_tokens})",
        f" • Reduce --tensor-parallel-size (current: {tensor_parallel_size})",
        " • Use more GPUs or GPUs with more VRAM",
        " • Try a smaller model",
    ):
        console.print(suggestion)
    console.print()
    raise ValueError("Minimum GPU memory requirements not met")

View File

@@ -0,0 +1,302 @@
"""
User Model Registry
Manages user-registered models in ~/.ktransformers/user_models.yaml
"""
from dataclasses import dataclass, asdict, field
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any
import yaml
# Constants
USER_MODELS_FILE = Path.home() / ".ktransformers" / "user_models.yaml"
REGISTRY_VERSION = "1.0"
@dataclass
class UserModel:
    """Represents a user-registered model"""

    name: str  # User-editable name (default: folder name)
    path: str  # Absolute path to model directory
    format: str  # "safetensors" | "gguf"
    id: Optional[str] = None  # Unique UUID for this model (auto-generated if None)
    repo_type: Optional[str] = None  # "huggingface" | "modelscope" | None
    repo_id: Optional[str] = None  # e.g., "deepseek-ai/DeepSeek-V3"
    sha256_status: str = "not_checked"  # "not_checked" | "checking" | "passed" | "failed" | "no_repo"
    gpu_model_ids: Optional[List[str]] = None  # For llamafile/AMX: list of GPU model UUIDs to run with
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    last_verified: Optional[str] = None  # ISO format datetime
    # MoE information (cached from analyze_moe_model)
    is_moe: Optional[bool] = None  # True if MoE model, False if non-MoE, None if not analyzed
    moe_num_experts: Optional[int] = None  # Total number of experts (for MoE models)
    moe_num_experts_per_tok: Optional[int] = None  # Number of active experts per token (for MoE models)
    # AMX quantization metadata (for format == "amx")
    amx_source_model: Optional[str] = None  # Name of the source MoE model that was quantized
    amx_quant_method: Optional[str] = None  # "int4" | "int8"
    amx_numa_nodes: Optional[int] = None  # Number of NUMA nodes used for quantization

    def __post_init__(self):
        """Ensure ID is set after initialization"""
        if self.id is None:
            import uuid

            self.id = str(uuid.uuid4())

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for YAML serialization"""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "UserModel":
        """Create from dictionary loaded from YAML

        Keys that are not fields of this class are ignored, so an entry with
        extra keys does not make the model (and with it the whole registry)
        unloadable. Missing required fields still raise TypeError.
        """
        from dataclasses import fields

        known = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in known})

    def path_exists(self) -> bool:
        """Check if model path still exists"""
        return Path(self.path).exists()
class UserModelRegistry:
    """Manages the user model registry.

    The registry is a list of UserModel entries persisted as YAML. It is
    loaded from disk on construction, and ``add_model``, ``remove_model`` and
    ``update_model`` write the file back after every successful change.
    """

    def __init__(self, registry_file: Optional[Path] = None):
        """
        Initialize the registry.

        Creates the parent directory of the registry file if needed, then
        loads the file (creating an empty one when it does not exist).

        Args:
            registry_file: Path to the registry YAML file (default: USER_MODELS_FILE)

        Raises:
            RuntimeError: If the registry file cannot be loaded or created.
        """
        self.registry_file = registry_file or USER_MODELS_FILE
        self.models: List["UserModel"] = []
        self.version = REGISTRY_VERSION

        # Ensure directory exists
        self.registry_file.parent.mkdir(parents=True, exist_ok=True)

        # Load existing registry
        self.load()

    def load(self) -> None:
        """
        Load models from the YAML file.

        A missing file is created as an empty registry. Entries stored
        without an ``id`` receive a generated UUID, and the registry is saved
        right away so that the generated IDs stay the same on later loads.

        Raises:
            RuntimeError: If the file cannot be read, parsed or written.
        """
        if not self.registry_file.exists():
            # Initialize empty registry
            self.models = []
            self.save()  # Create the file
            return

        try:
            with open(self.registry_file, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)

            if not data:
                self.models = []
                return

            # Load version
            self.version = data.get("version", REGISTRY_VERSION)

            # Load models; "models: null" in the file is treated as an empty list.
            models_data = data.get("models") or []

            # Migration must be detected on the raw entries: UserModel fills in
            # a fresh UUID during construction, so a missing ID can no longer
            # be seen on the instances afterwards.
            needs_save = any(entry.get("id") is None for entry in models_data)

            self.models = [UserModel.from_dict(entry) for entry in models_data]
        except Exception as e:
            raise RuntimeError(f"Failed to load user model registry: {e}") from e

        if needs_save:
            # Persist the generated UUIDs (backward compatibility with old files).
            self.save()

    def save(self) -> None:
        """
        Save models to the YAML file.

        Raises:
            RuntimeError: If the file cannot be written or serialized.
        """
        data = {"version": self.version, "models": [m.to_dict() for m in self.models]}
        try:
            with open(self.registry_file, "w", encoding="utf-8") as f:
                yaml.safe_dump(data, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
        except Exception as e:
            raise RuntimeError(f"Failed to save user model registry: {e}") from e

    def add_model(self, model: "UserModel") -> None:
        """
        Add a model to the registry and save.

        Args:
            model: UserModel instance to add

        Raises:
            ValueError: If a model with the same name already exists
        """
        if self.check_name_conflict(model.name):
            raise ValueError(f"Model with name '{model.name}' already exists")
        self.models.append(model)
        self.save()

    def remove_model(self, name: str) -> bool:
        """
        Remove a model from the registry.

        The registry is saved only when something was actually removed.

        Args:
            name: Name of the model to remove

        Returns:
            True if model was removed, False if not found
        """
        remaining = [m for m in self.models if m.name != name]
        if len(remaining) == len(self.models):
            return False
        self.models = remaining
        self.save()
        return True

    def update_model(self, name: str, updates: Dict[str, Any]) -> bool:
        """
        Update a model's attributes and save.

        Only dataclass fields of the model are updated; any other key in
        ``updates`` (unknown names, method names, dunder attributes) is
        ignored. No name-conflict check is done here when ``name`` itself is
        updated; use ``check_name_conflict`` beforehand.

        Args:
            name: Name of the model to update
            updates: Dictionary of attributes to update

        Returns:
            True if model was updated, False if not found
        """
        from dataclasses import fields

        model = self.get_model(name)
        if model is None:
            return False

        # Restrict to declared fields so that methods cannot be overwritten.
        settable = {f.name for f in fields(model)}
        for key, value in updates.items():
            if key in settable:
                setattr(model, key, value)

        self.save()
        return True

    def get_model(self, name: str) -> Optional["UserModel"]:
        """
        Get a model by name.

        Args:
            name: Name of the model

        Returns:
            UserModel instance or None if not found
        """
        for model in self.models:
            if model.name == name:
                return model
        return None

    def get_model_by_id(self, model_id: str) -> Optional["UserModel"]:
        """
        Get a model by its unique ID.

        Args:
            model_id: UUID of the model

        Returns:
            UserModel instance or None if not found
        """
        for model in self.models:
            if model.id == model_id:
                return model
        return None

    def list_models(self) -> List["UserModel"]:
        """
        List all models.

        Returns:
            A shallow copy of the list of registered UserModel instances
        """
        return self.models.copy()

    def find_by_path(self, path: str) -> Optional["UserModel"]:
        """
        Find a model by its path.

        Args:
            path: Model directory path

        Returns:
            UserModel instance or None if not found
        """
        # Normalize paths for comparison
        search_path = str(Path(path).resolve())
        for model in self.models:
            if str(Path(model.path).resolve()) == search_path:
                return model
        return None

    def check_name_conflict(self, name: str, exclude_name: Optional[str] = None) -> bool:
        """
        Check if a name conflicts with existing models.

        Args:
            name: Name to check
            exclude_name: Optional name to exclude from check (for rename operations)

        Returns:
            True if conflict exists, False otherwise
        """
        return any(m.name == name and m.name != exclude_name for m in self.models)

    def refresh_status(self) -> Dict[str, List[str]]:
        """
        Check all models and identify missing ones.

        Returns:
            Dictionary with 'valid' and 'missing' lists of model names
        """
        valid: List[str] = []
        missing: List[str] = []
        for model in self.models:
            (valid if model.path_exists() else missing).append(model.name)
        return {"valid": valid, "missing": missing}

    def get_model_count(self) -> int:
        """Get total number of registered models."""
        return len(self.models)

    def suggest_name(self, base_name: str) -> str:
        """
        Suggest a unique name based on base_name.

        Args:
            base_name: Base name to derive from

        Returns:
            A unique name (may have suffix like -2, -3 etc.)
        """
        # Collect the taken names once instead of rescanning for each candidate.
        taken = {m.name for m in self.models}
        if base_name not in taken:
            return base_name

        counter = 2
        while f"{base_name}-{counter}" in taken:
            counter += 1
        return f"{base_name}-{counter}"