mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-19 22:09:10 +00:00
kt-cli enhancement (#1834)
* [feat]: redesign kt run interactive configuration with i18n support - Redesign kt run with 8-step interactive flow (model selection, inference method, NUMA/CPU, GPU experts, KV cache, GPU/TP selection, parsers, host/port) - Add configuration save/load system (~/.ktransformers/run_configs.yaml) - Add i18n support for kt chat (en/zh translations) - Add universal input validators with auto-retry and Chinese comma support - Add port availability checker with auto-suggestion - Add parser configuration (--tool-call-parser, --reasoning-parser) - Remove tuna command and clean up redundant files - Fix: variable reference bug in run.py, filter to show only MoE models * [feat]: unify model selection UI and enable shared experts fusion by default - Unify kt run model selection table with kt model list display * Add Total size, MoE Size, Repo, and SHA256 status columns * Use consistent formatting and styling * Improve user decision-making with more information - Enable --disable-shared-experts-fusion by default * Change default value from False to True * Users can still override with --enable-shared-experts-fusion * [feat]: improve kt chat with performance metrics and better CJK support - Add performance metrics display after each response * Total time, TTFT (Time To First Token), TPOT (Time Per Output Token) * Accurate input/output token counts using model tokenizer * Fallback to estimation if tokenizer unavailable * Metrics shown in dim style (not prominent) - Fix Chinese character input issues * Replace Prompt.ask() with console.input() for better CJK support * Fixes backspace deletion showing half-characters - Suppress NumPy subnormal warnings * Filter "The value of the smallest subnormal" warnings * Cleaner CLI output on certain hardware environments * [fix]: correct TTFT measurement in kt chat - Move start_time initialization before API call - Previously start_time was set when receiving first chunk, causing TTFT ≈ 0ms - Now correctly measures time from request sent to 
first token received * [docs]: 添加 Clawdbot 集成指南 - KTransformers 企业级 AI 助手部署方案 * [docs]: 强调推荐使用 Kimi K2.5 作为核心模型,突出企业级推理能力 * [docs]: 添加 Clawdbot 飞书接入教程链接 * [feat]: improve CLI table display, model verification, and chat experience - Add sequence number (#) column to all model tables by default - Filter kt edit to show only MoE GPU models (exclude AMX) - Extend kt model verify to check *.json and *.py files in addition to weights - Fix re-verification bug where repaired files caused false failures - Suppress tokenizer debug output in kt chat token counting * [fix]: fix cpu cores. --------- Co-authored-by: skqliao <skqliao@gmail.com>
This commit is contained in:
413
kt-kernel/python/cli/utils/analyze_moe_model.py
Normal file
413
kt-kernel/python/cli/utils/analyze_moe_model.py
Normal file
@@ -0,0 +1,413 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
快速分析 MoE 模型 - 基于 config.json
|
||||
(复用 sglang 的模型注册表和判断逻辑)
|
||||
"""
|
||||
import json
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
def _get_sglang_moe_architectures():
|
||||
"""
|
||||
从 sglang 的模型注册表获取所有 MoE 架构
|
||||
|
||||
复用 sglang 的代码,这样 sglang 更新后自动支持新模型
|
||||
"""
|
||||
try:
|
||||
import sys
|
||||
|
||||
# 添加 sglang 路径到 sys.path
|
||||
sglang_path = Path("/mnt/data2/ljq/sglang/python")
|
||||
if sglang_path.exists() and str(sglang_path) not in sys.path:
|
||||
sys.path.insert(0, str(sglang_path))
|
||||
|
||||
# 直接导入 sglang 的 ModelRegistry
|
||||
# 注意:这需要 sglang 及其依赖正确安装
|
||||
from sglang.srt.models.registry import ModelRegistry
|
||||
|
||||
# 获取所有支持的架构
|
||||
supported_archs = ModelRegistry.get_supported_archs()
|
||||
|
||||
# 过滤出 MoE 模型(名称包含 Moe)
|
||||
moe_archs = {arch for arch in supported_archs if "Moe" in arch or "moe" in arch.lower()}
|
||||
|
||||
# 手动添加一些不带 "Moe" 字样但是 MoE 模型的架构
|
||||
# DeepSeek V2/V3 系列
|
||||
deepseek_moe = {arch for arch in supported_archs if arch.startswith("Deepseek") or arch.startswith("deepseek")}
|
||||
moe_archs.update(deepseek_moe)
|
||||
|
||||
# DBRX 也是 MoE 模型
|
||||
dbrx_moe = {arch for arch in supported_archs if "DBRX" in arch or "dbrx" in arch.lower()}
|
||||
moe_archs.update(dbrx_moe)
|
||||
|
||||
# Grok 也是 MoE 模型
|
||||
grok_moe = {arch for arch in supported_archs if "Grok" in arch or "grok" in arch.lower()}
|
||||
moe_archs.update(grok_moe)
|
||||
|
||||
return moe_archs
|
||||
except Exception as e:
|
||||
# 如果 sglang 不可用,返回空集合
|
||||
# 这种情况下,后续会使用配置文件中的其他判断方法
|
||||
import warnings
|
||||
|
||||
warnings.warn(f"Failed to load MoE architectures from sglang: {e}. Using fallback detection methods.")
|
||||
return set()
|
||||
|
||||
|
||||
# Known MoE architecture names, populated from sglang's registry at import
# time (empty set when sglang is unavailable).
MOE_ARCHITECTURES = _get_sglang_moe_architectures()
|
||||
|
||||
|
||||
def _get_cache_file():
|
||||
"""获取集中式缓存文件路径"""
|
||||
cache_dir = Path.home() / ".ktransformers" / "cache"
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cache_dir / "moe_analysis_v2.json"
|
||||
|
||||
|
||||
def _load_all_cache():
    """Read the shared cache file; return {} when missing or unreadable."""
    path = _get_cache_file()
    if not path.exists():
        return {}

    try:
        with open(path, "r") as fh:
            return json.load(fh)
    except Exception:
        # Corrupt cache is treated as empty rather than an error.
        return {}
|
||||
|
||||
|
||||
def _save_all_cache(cache_data):
    """Write the full cache dict to disk; warn (never raise) on failure."""
    path = _get_cache_file()
    try:
        with open(path, "w") as fh:
            json.dump(cache_data, fh, indent=2)
    except Exception as e:
        import warnings

        warnings.warn(f"Failed to save MoE cache: {e}")
|
||||
|
||||
|
||||
def _compute_config_fingerprint(config_path: Path) -> Optional[str]:
|
||||
"""计算 config.json 指纹"""
|
||||
if not config_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
stat = config_path.stat()
|
||||
# 使用文件大小和修改时间作为指纹
|
||||
fingerprint_str = f"{config_path.name}:{stat.st_size}:{int(stat.st_mtime)}"
|
||||
return hashlib.md5(fingerprint_str.encode()).hexdigest()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _load_cache(model_path: Path) -> Optional[Dict[str, Any]]:
    """Return the cached analysis for *model_path*, or None when stale/absent.

    An entry is honoured only when its cache version matches and the stored
    config.json fingerprint still matches the file currently on disk.
    """
    key = str(model_path.resolve())
    entries = _load_all_cache()

    if key not in entries:
        return None

    try:
        entry = entries[key]

        # Reject entries written by an older cache layout.
        if entry.get("cache_version", 0) != 2:
            return None

        # Reject entries whose config.json changed since they were written.
        current = _compute_config_fingerprint(model_path / "config.json")
        if entry.get("fingerprint") != current:
            return None

        return entry.get("result")
    except Exception:
        return None
|
||||
|
||||
|
||||
def _save_cache(model_path: Path, result: Dict[str, Any]):
    """Persist the analysis *result* for *model_path* in the shared cache.

    Stores the current config.json fingerprint so the entry can be
    invalidated when the config changes. Failures are reported as warnings
    and never raised.
    """
    key = str(model_path.resolve())

    try:
        # Proper import instead of the original __import__("datetime") hack.
        from datetime import datetime

        fingerprint = _compute_config_fingerprint(model_path / "config.json")

        all_cache = _load_all_cache()
        all_cache[key] = {
            "fingerprint": fingerprint,
            "result": result,
            "cache_version": 2,
            "last_updated": datetime.now().isoformat(),
        }
        _save_all_cache(all_cache)
    except Exception as e:
        import warnings

        warnings.warn(f"Failed to save MoE cache for {model_path}: {e}")
|
||||
|
||||
|
||||
def _load_config_json(model_path: Path) -> Optional[Dict[str, Any]]:
|
||||
"""读取 config.json 文件
|
||||
|
||||
参考 sglang 的 get_config() 实现
|
||||
"""
|
||||
config_path = model_path / "config.json"
|
||||
|
||||
if not config_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
return config
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _is_moe_model(config: Dict[str, Any]) -> bool:
    """Heuristically decide whether *config* describes a MoE model.

    Follows sglang's registry/architecture conventions.
    """
    # 1) Architecture name appears in the known MoE set.
    for arch in config.get("architectures", []):
        if arch in MOE_ARCHITECTURES:
            return True

    # 2) Mistral-style top-level "moe" section.
    if config.get("moe"):
        return True

    # 3) Expert-count fields; multimodal models nest them in "text_config".
    #    "n_routed_experts" is the Kimi-K2 / DeepSeek V3 spelling.
    text_config = config.get("text_config", config)
    expert_fields = ("num_experts", "num_local_experts", "n_routed_experts")
    return any(text_config.get(field) for field in expert_fields)
|
||||
|
||||
|
||||
def _extract_moe_params(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""从 config 中提取 MoE 参数
|
||||
|
||||
参考 sglang 的各种 MoE 模型实现
|
||||
"""
|
||||
# 处理嵌套的 text_config
|
||||
text_config = config.get("text_config", config)
|
||||
|
||||
# 提取基本参数
|
||||
result = {
|
||||
"architectures": config.get("architectures", []),
|
||||
"model_type": config.get("model_type", "unknown"),
|
||||
}
|
||||
|
||||
# 专家数量(不同模型字段名不同)
|
||||
num_experts = (
|
||||
text_config.get("num_experts") # Qwen2/3 MoE, DeepSeek V2
|
||||
or text_config.get("num_local_experts") # Mixtral
|
||||
or text_config.get("n_routed_experts") # Kimi-K2, DeepSeek V3
|
||||
or config.get("moe", {}).get("num_experts") # Mistral 格式
|
||||
)
|
||||
|
||||
# 每个 token 激活的专家数
|
||||
num_experts_per_tok = (
|
||||
text_config.get("num_experts_per_tok")
|
||||
or text_config.get("num_experts_per_token")
|
||||
or config.get("moe", {}).get("num_experts_per_tok")
|
||||
or 2 # 默认值
|
||||
)
|
||||
|
||||
# 层数
|
||||
num_hidden_layers = text_config.get("num_hidden_layers") or text_config.get("n_layer") or 0
|
||||
|
||||
# 隐藏层维度
|
||||
hidden_size = text_config.get("hidden_size") or text_config.get("d_model") or 0
|
||||
|
||||
# MoE 专家中间层大小
|
||||
moe_intermediate_size = (
|
||||
text_config.get("moe_intermediate_size")
|
||||
or text_config.get("intermediate_size") # 如果没有特殊的 moe_intermediate_size
|
||||
or 0
|
||||
)
|
||||
|
||||
# 共享专家中间层大小(Qwen2/3 MoE)
|
||||
shared_expert_intermediate_size = text_config.get("shared_expert_intermediate_size", 0)
|
||||
|
||||
result.update(
|
||||
{
|
||||
"num_experts": num_experts or 0,
|
||||
"num_experts_per_tok": num_experts_per_tok,
|
||||
"num_hidden_layers": num_hidden_layers,
|
||||
"hidden_size": hidden_size,
|
||||
"moe_intermediate_size": moe_intermediate_size,
|
||||
"shared_expert_intermediate_size": shared_expert_intermediate_size,
|
||||
}
|
||||
)
|
||||
|
||||
# 提取其他有用的参数
|
||||
result["num_attention_heads"] = text_config.get("num_attention_heads", 0)
|
||||
result["num_key_value_heads"] = text_config.get("num_key_value_heads", 0)
|
||||
result["vocab_size"] = text_config.get("vocab_size", 0)
|
||||
result["max_position_embeddings"] = text_config.get("max_position_embeddings", 0)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _estimate_model_size(model_path: Path) -> float:
|
||||
"""估算模型总大小(GB)
|
||||
|
||||
快速统计 safetensors 文件总大小
|
||||
"""
|
||||
try:
|
||||
total_size = 0
|
||||
for file_path in model_path.glob("*.safetensors"):
|
||||
total_size += file_path.stat().st_size
|
||||
return total_size / (1024**3)
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
|
||||
def analyze_moe_model(model_path, use_cache=True):
    """Quickly analyse a MoE model by reading only its config.json.

    Args:
        model_path: model directory (str or Path).
        use_cache: whether to consult/update the on-disk cache (default True).

    Returns:
        dict with keys: is_moe, num_experts, num_experts_per_tok,
        num_hidden_layers, hidden_size, moe_intermediate_size,
        shared_expert_intermediate_size, architectures, model_type,
        total_size_gb, cached, num_attention_heads, num_key_value_heads,
        vocab_size. Returns None when the path is missing, the model is not
        MoE, or analysis fails.
    """
    model_path = Path(model_path)
    if not model_path.exists():
        return None

    # Serve a cached result when it is still valid.
    if use_cache:
        cached = _load_cache(model_path)
        if cached:
            cached["cached"] = True
            return cached

    config = _load_config_json(model_path)
    if not config:
        return None

    if not _is_moe_model(config):
        return None

    params = _extract_moe_params(config)

    # An expert count is mandatory for a meaningful analysis.
    if params["num_experts"] == 0:
        return None

    result = {
        "is_moe": True,
        "num_experts": params["num_experts"],
        "num_experts_per_tok": params["num_experts_per_tok"],
        "num_hidden_layers": params["num_hidden_layers"],
        "hidden_size": params["hidden_size"],
        "moe_intermediate_size": params["moe_intermediate_size"],
        "shared_expert_intermediate_size": params["shared_expert_intermediate_size"],
        "architectures": params["architectures"],
        "model_type": params["model_type"],
        "total_size_gb": _estimate_model_size(model_path),
        "cached": False,
        # Extra attention/vocab details.
        "num_attention_heads": params.get("num_attention_heads", 0),
        "num_key_value_heads": params.get("num_key_value_heads", 0),
        "vocab_size": params.get("vocab_size", 0),
    }

    if use_cache:
        _save_cache(model_path, result)

    return result
|
||||
|
||||
|
||||
def print_analysis(model_path):
    """Run analyze_moe_model() and print a human-readable report."""
    print(f"分析模型: {model_path}\n")

    result = analyze_moe_model(model_path)
    if result is None:
        print("不是 MoE 模型或分析失败")
        return

    divider = "=" * 70
    print(divider)
    print("MoE 模型分析结果")
    if result.get("cached"):
        print("[使用缓存]")
    print(divider)
    print("模型架构:")
    print(f" - 架构: {', '.join(result['architectures'])}")
    print(f" - 类型: {result['model_type']}")
    print()
    print("MoE 结构:")
    print(f" - 专家总数: {result['num_experts']}")
    print(f" - 激活专家数: {result['num_experts_per_tok']} experts/token")
    print(f" - 层数: {result['num_hidden_layers']}")
    print(f" - 隐藏维度: {result['hidden_size']}")
    print(f" - MoE 中间层: {result['moe_intermediate_size']}")
    if result["shared_expert_intermediate_size"] > 0:
        print(f" - 共享专家中间层: {result['shared_expert_intermediate_size']}")
    print()
    print("大小统计:")
    print(f" - 模型总大小: {result['total_size_gb']:.2f} GB")
    print(divider)
    print()
|
||||
|
||||
|
||||
def main():
    """CLI entry point: analyse the model paths given on the command line.

    With no arguments, falls back to a built-in list of local model paths
    (development convenience). The original only honoured argv[1]; this
    accepts any number of paths, which is backward compatible.
    """
    import sys

    models = sys.argv[1:] or [
        "/mnt/data2/models/Qwen3-30B-A3B",
        "/mnt/data2/models/Qwen3-235B-A22B-Instruct-2507",
    ]

    for model_path in models:
        print_analysis(model_path)


if __name__ == "__main__":
    main()
|
||||
118
kt-kernel/python/cli/utils/debug_configs.py
Normal file
118
kt-kernel/python/cli/utils/debug_configs.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Debug utility to inspect saved run configurations.
|
||||
|
||||
Usage: python -m kt_kernel.cli.utils.debug_configs
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich import box
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def main():
    """Show all saved configurations."""
    config_file = Path.home() / ".ktransformers" / "run_configs.yaml"

    console.print()
    console.print(f"[bold]Configuration file:[/bold] {config_file}")
    console.print()

    # Bail out early when nothing has ever been saved.
    if not config_file.exists():
        console.print("[red]✗ Configuration file does not exist![/red]")
        console.print()
        console.print("No configurations have been saved yet.")
        return

    try:
        with open(config_file, "r", encoding="utf-8") as fh:
            data = yaml.safe_load(fh) or {}
    except Exception as e:
        console.print(f"[red]✗ Failed to load configuration file: {e}[/red]")
        return

    console.print(f"[green]✓[/green] Configuration file loaded")
    console.print()

    configs = data.get("configs", {})
    if not configs:
        console.print("[yellow]No saved configurations found.[/yellow]")
        return

    console.print(f"[bold]Found configurations for {len(configs)} model(s):[/bold]")
    console.print()

    for model_id, model_configs in configs.items():
        console.print(f"[cyan]Model ID:[/cyan] {model_id}")
        console.print(f"[dim] {len(model_configs)} configuration(s)[/dim]")
        console.print()

        if not model_configs:
            continue

        # One table per model listing its saved run configurations.
        table = Table(box=box.ROUNDED, show_header=True, header_style="bold cyan")
        for heading, kwargs in (
            ("#", {"justify": "right", "style": "cyan"}),
            ("Name", {"style": "white"}),
            ("Method", {"style": "yellow"}),
            ("TP", {"justify": "right", "style": "green"}),
            ("GPU Experts", {"justify": "right", "style": "magenta"}),
            ("Created", {"style": "dim"}),
        ):
            table.add_column(heading, **kwargs)

        for idx, entry in enumerate(model_configs, 1):
            method = entry.get("inference_method", "?")
            method_display = f"{method.upper()}"
            if method == "raw":
                method_display += f" ({entry.get('raw_method', '?')})"
            elif method == "amx":
                method_display += f" ({entry.get('kt_method', '?')})"

            created = entry.get("created_at")
            table.add_row(
                str(idx),
                entry.get("config_name", f"Config {idx}"),
                method_display,
                str(entry.get("tp_size", "?")),
                str(entry.get("gpu_experts", "?")),
                created[:19] if created else "Unknown",
            )

        console.print(table)
        console.print()

    # Cross-reference against the model registry so IDs can be named.
    console.print("[bold]Checking model registry...[/bold]")
    console.print()

    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry

    try:
        registry = UserModelRegistry()
        all_models = registry.list_models()

        console.print(f"[green]✓[/green] Found {len(all_models)} registered model(s)")
        console.print()

        # Map model IDs to names
        id_to_name = {m.id: m.name for m in all_models}

        console.print("[bold]Model ID → Name mapping:[/bold]")
        console.print()

        for model_id in configs.keys():
            model_name = id_to_name.get(model_id, "[red]Unknown (model not found in registry)[/red]")
            console.print(f" {model_id[:8]}... → {model_name}")

        console.print()

    except Exception as e:
        console.print(f"[yellow]⚠ Could not load model registry: {e}[/yellow]")
        console.print()


if __name__ == "__main__":
    main()
|
||||
146
kt-kernel/python/cli/utils/download_helper.py
Normal file
146
kt-kernel/python/cli/utils/download_helper.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Helper functions for interactive model download."""
|
||||
|
||||
import fnmatch
from pathlib import Path
from typing import Any, Dict, List, Tuple
|
||||
|
||||
|
||||
def list_remote_files_hf(repo_id: str, use_mirror: bool = False) -> List[Dict[str, Any]]:
    """
    List files in a HuggingFace repository.

    Args:
        repo_id: repository identifier, e.g. "org/model".
        use_mirror: route requests through hf-mirror.com when no HF_ENDPOINT
            is already configured.

    Returns:
        List of dicts with keys: 'path', 'size' (in bytes)
    """
    from huggingface_hub import HfApi
    import os

    # Temporarily point huggingface_hub at the mirror; restored in `finally`.
    original_endpoint = os.environ.get("HF_ENDPOINT")
    if use_mirror and not original_endpoint:
        os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

    try:
        api = HfApi()
        files_info = api.list_repo_tree(repo_id=repo_id, recursive=True)

        result = []
        for item in files_info:
            # Directories are not downloadable entries; skip them.
            if hasattr(item, "type") and item.type == "directory":
                continue

            result.append(
                {
                    "path": item.path if hasattr(item, "path") else str(item),
                    "size": item.size if hasattr(item, "size") else 0,
                }
            )

        return result
    finally:
        # Undo the temporary endpoint override.
        if use_mirror and not original_endpoint:
            os.environ.pop("HF_ENDPOINT", None)
        elif original_endpoint:
            os.environ["HF_ENDPOINT"] = original_endpoint
|
||||
|
||||
|
||||
def list_remote_files_ms(repo_id: str) -> List[Dict[str, Any]]:
    """
    List files in a ModelScope repository.

    Returns:
        List of dicts with keys: 'path', 'size' (in bytes)
    """
    from modelscope.hub.api import HubApi

    api = HubApi()
    files_info = api.get_model_files(model_id=repo_id, recursive=True)

    # ModelScope file records use "Name"/"Path" and "Size" keys.
    return [
        {"path": info.get("Name", info.get("Path", "")), "size": info.get("Size", 0)}
        for info in files_info
    ]
|
||||
|
||||
|
||||
def filter_files_by_pattern(files: List[Dict[str, Any]], pattern: str) -> List[Dict[str, Any]]:
    """Filter files by glob *pattern*.

    A file matches when either its basename or its full path matches.
    "*" is treated as "keep everything" and short-circuits.
    """
    if pattern == "*":
        return files

    matched = []
    for entry in files:
        full_path = entry["path"]
        basename = Path(full_path).name
        if fnmatch.fnmatch(basename, pattern) or fnmatch.fnmatch(full_path, pattern):
            matched.append(entry)

    return matched
|
||||
|
||||
|
||||
def calculate_total_size(files: List[Dict[str, Any]]) -> int:
    """Return the summed 'size' of *files* in bytes (0 for an empty list)."""
    return sum(entry["size"] for entry in files)
|
||||
|
||||
|
||||
def format_file_list_table(files: List[Dict[str, Any]], max_display: int = 10):
    """Build a rich Table listing up to *max_display* files with sizes.

    Files beyond the limit are summarised in a trailing "... and N more" row.
    """
    from rich.table import Table
    from kt_kernel.cli.utils.model_scanner import format_size

    table = Table(show_header=True, header_style="bold")
    table.add_column("File", style="cyan", overflow="fold")
    table.add_column("Size", justify="right")

    # Show the first max_display files only.
    for entry in files[:max_display]:
        table.add_row(entry["path"], format_size(entry["size"]))

    hidden = len(files) - max_display
    if hidden > 0:
        table.add_row(f"... and {hidden} more files", "[dim]...[/dim]")

    return table
|
||||
|
||||
|
||||
def verify_repo_exists(repo_id: str, repo_type: str, use_mirror: bool = False) -> Tuple[bool, str]:
    """
    Verify if a repository exists.

    Args:
        repo_id: repository identifier.
        repo_type: "huggingface", anything else means ModelScope.
        use_mirror: use hf-mirror.com when no HF_ENDPOINT is configured.

    Returns:
        (exists: bool, message: str)
    """
    try:
        if repo_type == "huggingface":
            import os

            # Import BEFORE mutating the environment: previously a failed
            # import skipped the inner `finally` and left HF_ENDPOINT
            # pointing at the mirror for the rest of the process.
            from huggingface_hub import HfApi

            original_endpoint = os.environ.get("HF_ENDPOINT")
            if use_mirror and not original_endpoint:
                os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

            try:
                api = HfApi()
                api.repo_info(repo_id=repo_id, repo_type="model")
                return True, "Repository found"
            finally:
                # Restore the previous endpoint state.
                if use_mirror and not original_endpoint:
                    os.environ.pop("HF_ENDPOINT", None)
                elif original_endpoint:
                    os.environ["HF_ENDPOINT"] = original_endpoint

        else:  # modelscope
            from modelscope.hub.api import HubApi

            api = HubApi()
            api.get_model(model_id=repo_id)
            return True, "Repository found"

    except Exception as e:
        return False, f"Repository not found: {str(e)}"
|
||||
216
kt-kernel/python/cli/utils/input_validators.py
Normal file
216
kt-kernel/python/cli/utils/input_validators.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
Input validation utilities with retry mechanism.
|
||||
|
||||
Provides robust input validation with automatic retry on failure.
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Callable, Any
|
||||
from rich.console import Console
|
||||
from rich.prompt import Prompt
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def prompt_int_with_retry(
    message: str,
    default: Optional[int] = None,
    min_val: Optional[int] = None,
    max_val: Optional[int] = None,
    validator: Optional[Callable[[int], bool]] = None,
    validator_error_msg: Optional[str] = None,
) -> int:
    """Prompt for an integer, re-asking until every check passes.

    Args:
        message: Prompt message
        default: Default value (optional)
        min_val: Minimum allowed value (optional)
        max_val: Maximum allowed value (optional)
        validator: Custom validation function (optional)
        validator_error_msg: Error message for custom validator (optional)

    Returns:
        Validated integer value
    """

    def reject(reason: str) -> None:
        # Report the problem, then fall through to re-prompt.
        console.print(f"[red]✗ {reason}[/red]")
        console.print()

    # The displayed prompt does not change between retries.
    prompt_text = message if default is None else f"{message} [{default}]"

    while True:
        raw = Prompt.ask(prompt_text, default=str(default) if default is not None else None)

        try:
            value = int(raw)
        except ValueError:
            reject("Invalid input. Please enter a valid integer.")
            continue

        if min_val is not None and value < min_val:
            reject(f"Value must be at least {min_val}")
            continue
        if max_val is not None and value > max_val:
            reject(f"Value must be at most {max_val}")
            continue
        if validator is not None and not validator(value):
            reject(validator_error_msg or "Invalid value")
            continue

        return value
|
||||
|
||||
|
||||
def prompt_float_with_retry(
    message: str,
    default: Optional[float] = None,
    min_val: Optional[float] = None,
    max_val: Optional[float] = None,
) -> float:
    """Prompt for a floating-point number, re-asking until it is valid.

    Args:
        message: Prompt message
        default: Default value (optional)
        min_val: Minimum allowed value (optional)
        max_val: Maximum allowed value (optional)

    Returns:
        Validated float value
    """

    def reject(reason: str) -> None:
        # Report the problem, then fall through to re-prompt.
        console.print(f"[red]✗ {reason}[/red]")
        console.print()

    prompt_text = message if default is None else f"{message} [{default}]"

    while True:
        raw = Prompt.ask(prompt_text, default=str(default) if default is not None else None)

        try:
            value = float(raw)
        except ValueError:
            reject("Invalid input. Please enter a valid number.")
            continue

        if min_val is not None and value < min_val:
            reject(f"Value must be at least {min_val}")
            continue
        if max_val is not None and value > max_val:
            reject(f"Value must be at most {max_val}")
            continue

        return value
|
||||
|
||||
|
||||
def prompt_choice_with_retry(
    message: str,
    choices: List[str],
    default: Optional[str] = None,
) -> str:
    """Prompt until the user picks one of *choices*.

    Args:
        message: Prompt message
        choices: List of valid choices
        default: Default choice (optional)

    Returns:
        Selected choice
    """
    while True:
        answer = Prompt.ask(message, default=default)
        if answer in choices:
            return answer
        console.print(f"[red]✗ Invalid choice. Please select from: {', '.join(choices)}[/red]")
        console.print()
|
||||
|
||||
|
||||
def prompt_int_list_with_retry(
    message: str,
    default: Optional[str] = None,
    min_val: Optional[int] = None,
    max_val: Optional[int] = None,
    validator: Optional[Callable[[List[int]], tuple[bool, Optional[str]]]] = None,
) -> List[int]:
    """Prompt for a comma-separated list of integers, re-asking on error.

    Accepts full-width (Chinese) commas and ignores spaces.

    Args:
        message: Prompt message
        default: Default value as string (e.g., "0,1,2,3")
        min_val: Minimum allowed value for each integer (optional)
        max_val: Maximum allowed value for each integer (optional)
        validator: Custom validation function that returns (is_valid, error_message) (optional)

    Returns:
        List of validated integers
    """
    while True:
        raw = Prompt.ask(message, default=default)

        # Normalise: full-width comma → ASCII comma, drop all spaces.
        cleaned = raw.replace(",", ",").replace(" ", "")

        try:
            values = [int(token) for token in cleaned.split(",") if token]
        except ValueError:
            console.print("[red]✗ Invalid format. Please enter numbers separated by commas.[/red]")
            console.print()
            continue

        # Collect values outside [min_val, max_val] (preserving input order).
        out_of_range = [
            v
            for v in values
            if (min_val is not None and v < min_val) or (max_val is not None and v > max_val)
        ]
        if out_of_range:
            if min_val is not None and max_val is not None:
                console.print(f"[red]✗ Invalid value(s): {out_of_range}[/red]")
                console.print(f"[yellow]Valid range: {min_val}-{max_val}[/yellow]")
            elif min_val is not None:
                console.print(f"[red]✗ Value(s) must be at least {min_val}: {out_of_range}[/red]")
            elif max_val is not None:
                console.print(f"[red]✗ Value(s) must be at most {max_val}: {out_of_range}[/red]")
            console.print()
            continue

        if validator is not None:
            is_valid, error_msg = validator(values)
            if not is_valid:
                console.print(f"[red]✗ {error_msg}[/red]")
                console.print()
                continue

        return values
|
||||
207
kt-kernel/python/cli/utils/kv_cache_calculator.py
Normal file
207
kt-kernel/python/cli/utils/kv_cache_calculator.py
Normal file
@@ -0,0 +1,207 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
KV Cache Size Calculator for SGLang
|
||||
|
||||
This script calculates the KV cache size in GB for a given model and number of tokens.
|
||||
It follows the same logic as in sglang/srt/model_executor/model_runner.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from transformers import AutoConfig
|
||||
|
||||
# Add sglang to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "python"))
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig, is_deepseek_nsa, get_nsa_index_head_dim
|
||||
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
|
||||
|
||||
|
||||
def get_dtype_bytes(dtype_str: str) -> int:
    """Return the per-element byte width for a KV-cache dtype name.

    Args:
        dtype_str: Dtype name such as "float32", "bfloat16", "float8_e4m3fn",
            or "auto".

    Returns:
        Byte width of one element; unknown names fall back to 2 (half precision).
    """
    if dtype_str == "float32":
        return 4
    if dtype_str in ("float8_e4m3fn", "float8_e5m2"):
        return 1
    # float16 / bfloat16 / "auto" (usually resolves to bfloat16) / unknown
    return 2
|
||||
|
||||
|
||||
def get_kv_size_gb(
    model_path: str,
    max_total_tokens: int,
    tp: int = 1,
    dtype: str = "auto",
    verbose: bool = True,
) -> dict:
    """
    Calculate the KV cache size in GB for a given model and number of tokens.

    Mirrors the sizing logic in sglang's model_runner: MLA models store a
    single compressed latent per token, while standard MHA models store
    separate K and V buffers per KV head.

    Args:
        model_path: Path to the model
        max_total_tokens: Maximum number of tokens to cache
        tp: Tensor parallelism size (affects per-rank KV head count for MHA)
        dtype: Data type for KV cache (auto, float16, bfloat16, float8_e4m3fn, etc.)
        verbose: Whether to print detailed information

    Returns:
        dict: Dictionary containing calculation details (cell sizes, per-buffer
        sizes for MHA, and total_size_bytes / total_size_gb)
    """
    # Load model config via sglang's ModelConfig wrapper.
    model_config = ModelConfig(model_path, dtype=dtype)
    hf_config = model_config.hf_config

    # Determine dtype bytes
    dtype_bytes = get_dtype_bytes(dtype)
    if dtype == "auto":
        # Auto dtype usually becomes bfloat16 (2 bytes); redundant with the
        # map's "auto" entry but kept explicit for clarity.
        dtype_bytes = 2

    # Number of attention layers that hold KV cache.
    num_layers = model_config.num_attention_layers

    # Check if it's MLA (Multi-head Latent Attention) model
    is_mla = hasattr(model_config, "attention_arch") and model_config.attention_arch.name == "MLA"

    result = {
        "model_path": model_path,
        "max_total_tokens": max_total_tokens,
        "tp": tp,
        "dtype": dtype,
        "dtype_bytes": dtype_bytes,
        "num_layers": num_layers,
        "is_mla": is_mla,
    }

    if is_mla:
        # MLA models (DeepSeek-V2/V3, MiniCPM3, etc.): one compressed latent
        # of (kv_lora_rank + qk_rope_head_dim) elements per token per layer.
        kv_lora_rank = model_config.kv_lora_rank
        qk_rope_head_dim = model_config.qk_rope_head_dim

        # Calculate cell size (bytes per cached token across all layers)
        cell_size = (kv_lora_rank + qk_rope_head_dim) * num_layers * dtype_bytes

        result.update(
            {
                "kv_lora_rank": kv_lora_rank,
                "qk_rope_head_dim": qk_rope_head_dim,
                "cell_size_bytes": cell_size,
            }
        )

        # Check if it's NSA (Native Sparse Attention) model — these carry an
        # extra per-token indexer buffer on top of the MLA latent.
        if is_deepseek_nsa(hf_config):
            index_head_dim = get_nsa_index_head_dim(hf_config)
            # index values plus one 4-byte scale per quant block
            indexer_size_per_token = index_head_dim + index_head_dim // NSATokenToKVPool.quant_block_size * 4
            indexer_dtype_bytes = torch._utils._element_size(NSATokenToKVPool.index_k_with_scale_buffer_dtype)
            indexer_cell_size = indexer_size_per_token * num_layers * indexer_dtype_bytes
            cell_size += indexer_cell_size

            result.update(
                {
                    "is_nsa": True,
                    "index_head_dim": index_head_dim,
                    "indexer_cell_size_bytes": indexer_cell_size,
                    "total_cell_size_bytes": cell_size,
                }
            )
        else:
            result["is_nsa"] = False
    else:
        # Standard MHA models: K and V stored per KV head; head count is
        # divided across tensor-parallel ranks.
        num_kv_heads = model_config.get_num_kv_heads(tp)
        head_dim = model_config.head_dim
        v_head_dim = model_config.v_head_dim

        # Calculate cell size (bytes per cached token across all layers)
        cell_size = num_kv_heads * (head_dim + v_head_dim) * num_layers * dtype_bytes

        result.update(
            {
                "num_kv_heads": num_kv_heads,
                "head_dim": head_dim,
                "v_head_dim": v_head_dim,
                "cell_size_bytes": cell_size,
            }
        )

    # Calculate total KV cache size
    total_size_bytes = max_total_tokens * cell_size
    total_size_gb = total_size_bytes / (1024**3)

    # For MHA models, also report the separate K and V buffer sizes.
    if not is_mla:
        k_size_bytes = max_total_tokens * num_kv_heads * head_dim * num_layers * dtype_bytes
        v_size_bytes = max_total_tokens * num_kv_heads * v_head_dim * num_layers * dtype_bytes
        k_size_gb = k_size_bytes / (1024**3)
        v_size_gb = v_size_bytes / (1024**3)

        result.update(
            {
                "k_size_gb": k_size_gb,
                "v_size_gb": v_size_gb,
            }
        )

    result.update(
        {
            "total_size_bytes": total_size_bytes,
            "total_size_gb": total_size_gb,
        }
    )

    if verbose:
        # NOTE: the branch-local names (kv_lora_rank, num_kv_heads, ...) are
        # only bound on the matching is_mla path, which the prints mirror.
        print(f"Model: {model_path}")
        print(f"Tokens: {max_total_tokens}, TP: {tp}, Dtype: {dtype}")
        print(f"Architecture: {'MLA' if is_mla else 'MHA'}")
        print(f"Layers: {num_layers}")

        if is_mla:
            print(f"KV LoRA Rank: {kv_lora_rank}, QK RoPE Head Dim: {qk_rope_head_dim}")
            if result.get("is_nsa"):
                print(f"NSA Index Head Dim: {index_head_dim}")
                print(
                    f"Cell size: {cell_size} bytes (Main: {result['cell_size_bytes']}, Indexer: {result['indexer_cell_size_bytes']})"
                )
            else:
                print(f"Cell size: {cell_size} bytes")
        else:
            print(f"KV Heads: {num_kv_heads}, Head Dim: {head_dim}, V Head Dim: {v_head_dim}")
            print(f"Cell size: {cell_size} bytes")
            print(f"K size: {k_size_gb:.2f} GB, V size: {v_size_gb:.2f} GB")

        print(f"Total KV Cache Size: {total_size_gb:.2f} GB")

    return result
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: parse arguments and compute the KV cache size.

    In quiet mode only the total size in GB (two decimals) is printed.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Calculate KV cache size for a model")
    parser.add_argument("model_path", help="Path to the model")
    parser.add_argument("max_total_tokens", type=int, help="Maximum number of tokens")
    parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism size")
    parser.add_argument("--dtype", type=str, default="auto", help="Data type (auto, float16, bfloat16, etc.)")
    parser.add_argument("--quiet", action="store_true", help="Suppress verbose output")
    args = parser.parse_args()

    info = get_kv_size_gb(
        args.model_path,
        args.max_total_tokens,
        tp=args.tp,
        dtype=args.dtype,
        verbose=not args.quiet,
    )

    # Quiet mode prints a single machine-readable number.
    if args.quiet:
        print(f"{info['total_size_gb']:.2f}")
|
||||
|
||||
|
||||
# Allow running this calculator directly from the command line.
if __name__ == "__main__":
    main()
|
||||
250
kt-kernel/python/cli/utils/model_discovery.py
Normal file
250
kt-kernel/python/cli/utils/model_discovery.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
Model Discovery Utilities
|
||||
|
||||
Shared functions for discovering and registering new models across different commands.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from pathlib import Path
|
||||
from rich.console import Console
|
||||
|
||||
from kt_kernel.cli.utils.model_scanner import (
|
||||
discover_models,
|
||||
scan_directory_for_models,
|
||||
ScannedModel,
|
||||
)
|
||||
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry, UserModel
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def discover_and_register_global(
    min_size_gb: float = 2.0, max_depth: int = 6, show_progress: bool = True, lang: str = "en"
) -> Tuple[int, int, List[UserModel]]:
    """
    Run a system-wide model scan and register any models not yet in the registry.

    Args:
        min_size_gb: Minimum model size in GB
        max_depth: Maximum search depth
        show_progress: Whether to show progress messages
        lang: Language for messages ("en" or "zh")

    Returns:
        Tuple of (total_found, new_found, registered_models)
    """
    registry = UserModelRegistry()

    if show_progress:
        message = (
            "[dim]正在扫描系统中的模型权重,这可能需要30-60秒...[/dim]"
            if lang == "zh"
            else "[dim]Scanning system for model weights, this may take 30-60 seconds...[/dim]"
        )
        console.print(message)

    # Scan every mount point (mount_points=None means "auto-detect").
    all_models = discover_models(mount_points=None, min_size_gb=min_size_gb, max_depth=max_depth)

    # Keep only models the registry has not seen yet.
    unseen = [candidate for candidate in all_models if not registry.find_by_path(candidate.path)]

    registered = []
    for candidate in unseen:
        added = _create_and_register_model(registry, candidate)
        if added:
            registered.append(added)

    return len(all_models), len(unseen), registered
|
||||
|
||||
|
||||
def discover_and_register_path(
    path: str,
    min_size_gb: float = 2.0,
    existing_paths: Optional[set] = None,
    show_progress: bool = True,
    lang: str = "en",
) -> Tuple[int, int, List[UserModel]]:
    """
    Discover models under a specific directory and register the new ones.

    Args:
        path: Directory path to scan
        min_size_gb: Minimum model file size in GB
        existing_paths: Set of already discovered paths in this session (optional)
        show_progress: Whether to show progress messages
        lang: Language for messages ("en" or "zh")

    Returns:
        Tuple of (total_found, new_found, registered_models)
    """
    registry = UserModelRegistry()

    if show_progress:
        if lang == "zh":
            console.print(f"[dim]正在扫描 {path}...[/dim]")
        else:
            console.print(f"[dim]Scanning {path}...[/dim]")

    # Locate candidate model directories under `path`.
    model_info = scan_directory_for_models(path, min_file_size_gb=min_size_gb)
    if not model_info:
        return 0, 0, []

    # Build ScannedModel objects for directories that are genuinely new.
    fresh = []
    for dir_path, (format_type, size_bytes, file_count, files) in model_info.items():
        if registry.find_by_path(dir_path):
            continue  # already in the persistent registry
        if existing_paths and dir_path in existing_paths:
            continue  # already discovered earlier in this session
        fresh.append(
            ScannedModel(
                path=dir_path, format=format_type, size_bytes=size_bytes, file_count=file_count, files=files
            )
        )

    registered = []
    for candidate in fresh:
        added = _create_and_register_model(registry, candidate)
        if added:
            registered.append(added)

    return len(model_info), len(fresh), registered
|
||||
|
||||
|
||||
def _create_and_register_model(registry: UserModelRegistry, scanned_model: ScannedModel) -> Optional[UserModel]:
    """
    Build a UserModel from *scanned_model* and add it to *registry*.

    Name conflicts are resolved with registry.suggest_name (model-2, model-3, ...).
    Enrichment is best-effort and never fatal: repo_id is read from README.md
    YAML frontmatter, and MoE metadata is detected for safetensors models.

    Args:
        registry: UserModelRegistry instance
        scanned_model: ScannedModel to register

    Returns:
        The registered UserModel, or None if registration failed.
    """
    model = UserModel(
        name=registry.suggest_name(scanned_model.folder_name),
        path=scanned_model.path,
        format=scanned_model.format,
    )

    # Best-effort: detect the upstream repo from README.md YAML frontmatter.
    try:
        from kt_kernel.cli.utils.repo_detector import detect_repo_for_model

        repo_info = detect_repo_for_model(scanned_model.path)
        if repo_info:
            model.repo_id, model.repo_type = repo_info
    except Exception:
        pass  # detection failure is non-fatal

    # Best-effort: detect and cache MoE structure (safetensors only).
    if scanned_model.format == "safetensors":
        try:
            from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model

            moe = analyze_moe_model(scanned_model.path, use_cache=True)
            if moe and moe.get("is_moe"):
                model.is_moe = True
                model.moe_num_experts = moe.get("num_experts")
                model.moe_num_experts_per_tok = moe.get("num_experts_per_tok")
            else:
                model.is_moe = False
        except Exception:
            pass  # is_moe stays None when detection fails

    try:
        registry.add_model(model)
        return model
    except Exception:
        # Should not happen since we used suggest_name, but handle gracefully.
        return None
|
||||
|
||||
|
||||
def format_discovery_summary(
    total_found: int,
    new_found: int,
    registered: List[UserModel],
    lang: str = "en",
    show_models: bool = True,
    max_show: int = 10,
) -> None:
    """
    Print formatted discovery summary.

    Args:
        total_found: Total models found
        new_found: New models found
        registered: List of registered UserModel objects
        lang: Language ("en" or "zh")
        show_models: Whether to show model list
        max_show: Maximum models to show
    """
    console.print()

    # Nothing new: either everything found was already registered, or
    # nothing was found at all.
    if new_found == 0:
        if total_found > 0:
            if lang == "zh":
                console.print(f"[green]✓[/green] 扫描完成:找到 {total_found} 个模型,所有模型均已在列表中")
            else:
                console.print(f"[green]✓[/green] Scan complete: found {total_found} models, all already in the list")
        else:
            if lang == "zh":
                console.print("[yellow]未找到模型[/yellow]")
            else:
                console.print("[yellow]No models found[/yellow]")
        return

    # Show summary
    if lang == "zh":
        console.print(f"[green]✓[/green] 扫描完成:找到 {total_found} 个模型,其中 {new_found} 个为新模型")
    else:
        console.print(f"[green]✓[/green] Scan complete: found {total_found} models, {new_found} are new")

    # Show registered count
    if len(registered) > 0:
        if lang == "zh":
            console.print(f"[green]✓[/green] 成功添加 {len(registered)} 个新模型到列表")
        else:
            console.print(f"[green]✓[/green] Successfully added {len(registered)} new models to list")

    # Show model list
    if show_models and registered:
        console.print()
        if lang == "zh":
            console.print(f"[dim]新发现的模型(前{max_show}个):[/dim]")
        else:
            console.print(f"[dim]Newly discovered models (first {max_show}):[/dim]")

        # Fix: removed the dead local `size_str = "?.? GB"` — size display
        # was never implemented, so only name/format/path are printed.
        for i, model in enumerate(registered[:max_show], 1):
            console.print(f" {i}. {model.name} ({model.format})")
            console.print(f" [dim]{model.path}[/dim]")

        if len(registered) > max_show:
            remaining = len(registered) - max_show
            if lang == "zh":
                console.print(f" [dim]... 还有 {remaining} 个新模型[/dim]")
            else:
                console.print(f" [dim]... and {remaining} more new models[/dim]")
|
||||
790
kt-kernel/python/cli/utils/model_scanner.py
Normal file
790
kt-kernel/python/cli/utils/model_scanner.py
Normal file
@@ -0,0 +1,790 @@
|
||||
"""
|
||||
Model Scanner
|
||||
|
||||
Scans directories for model files (safetensors, gguf) and identifies models
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Set, Tuple, Dict
|
||||
from collections import defaultdict
|
||||
import os
|
||||
import subprocess
|
||||
import json
|
||||
|
||||
|
||||
@dataclass
class ScannedModel:
    """Temporary structure for scanned model information"""

    path: str  # Absolute path to model directory
    format: str  # "safetensors" | "gguf" | "mixed"
    size_bytes: int  # Total size in bytes
    file_count: int  # Number of model files
    files: List[str]  # List of model file names

    @property
    def size_gb(self) -> float:
        """Total size converted to gibibytes."""
        return self.size_bytes / (1 << 30)

    @property
    def folder_name(self) -> str:
        """Last path component, used as the default model name."""
        return Path(self.path).name
|
||||
|
||||
|
||||
class ModelScanner:
    """Scanner for discovering models in directory trees"""

    def __init__(self, min_size_gb: float = 10.0):
        """
        Initialize scanner

        Args:
            min_size_gb: Minimum folder size in GB to be considered a model
        """
        # Threshold is stored in bytes so walk-time comparisons are integer-only.
        self.min_size_bytes = int(min_size_gb * 1024**3)

    def scan_directory(
        self, base_path: Path, exclude_paths: Optional[Set[str]] = None
    ) -> Tuple[List[ScannedModel], List[str]]:
        """
        Scan directory tree for models

        Args:
            base_path: Root directory to scan
            exclude_paths: Set of absolute paths to exclude from results

        Returns:
            Tuple of (valid_models, warnings)
            - valid_models: List of ScannedModel instances
            - warnings: List of warning messages

        Raises:
            ValueError: If base_path does not exist or is not a directory.
        """
        if not base_path.exists():
            raise ValueError(f"Path does not exist: {base_path}")

        if not base_path.is_dir():
            raise ValueError(f"Path is not a directory: {base_path}")

        # NOTE: an explicitly-passed empty set is treated the same as None here.
        exclude_paths = exclude_paths or set()
        results: List[ScannedModel] = []
        warnings: List[str] = []

        # Walk the directory tree
        for root, dirs, files in os.walk(base_path):
            root_path = Path(root).resolve()

            # Skip if already registered
            if str(root_path) in exclude_paths:
                dirs[:] = []  # Don't descend into this directory
                continue

            # Check for model files
            safetensors_files = [f for f in files if f.endswith(".safetensors")]
            gguf_files = [f for f in files if f.endswith(".gguf")]

            if not safetensors_files and not gguf_files:
                continue  # No model files in this directory

            # Calculate total size
            model_files = safetensors_files + gguf_files
            total_size = self._calculate_total_size(root_path, model_files)

            # Check if size meets minimum threshold.
            # Note: `continue` does NOT prune `dirs`, so subdirectories of a
            # too-small folder are still scanned on later walk iterations.
            if total_size < self.min_size_bytes:
                continue  # Too small, but keep scanning subdirectories

            # Detect format
            if safetensors_files and gguf_files:
                # Mixed format - issue warning and skip the whole subtree.
                warnings.append(
                    f"Mixed format detected in {root_path}: "
                    f"{len(safetensors_files)} safetensors + {len(gguf_files)} gguf files. "
                    "Please separate into different folders and re-scan."
                )
                dirs[:] = []  # Don't descend into mixed format directories
                continue

            # Determine format
            format_type = "safetensors" if safetensors_files else "gguf"

            # Create scanned model
            scanned = ScannedModel(
                path=str(root_path),
                format=format_type,
                size_bytes=total_size,
                file_count=len(model_files),
                files=model_files,
            )

            results.append(scanned)

            # Continue scanning subdirectories - they might also contain models
            # Each subdirectory will be independently checked for size >= 10GB

        return results, warnings

    def scan_single_path(self, path: Path) -> Optional[ScannedModel]:
        """
        Scan a single path for model files

        Args:
            path: Path to scan

        Returns:
            ScannedModel instance or None if not a valid model

        Raises:
            ValueError: If the directory mixes safetensors and gguf files.
        """
        if not path.exists() or not path.is_dir():
            return None

        # Find model files (non-recursive: only direct children of `path`).
        safetensors_files = list(path.glob("*.safetensors"))
        gguf_files = list(path.glob("*.gguf"))

        if not safetensors_files and not gguf_files:
            return None

        # Check for mixed format
        if safetensors_files and gguf_files:
            raise ValueError(
                f"Mixed format detected: {len(safetensors_files)} safetensors + "
                f"{len(gguf_files)} gguf files. Please use a single format."
            )

        # Calculate size
        model_files = [f.name for f in safetensors_files + gguf_files]
        total_size = self._calculate_total_size(path, model_files)

        # Determine format
        format_type = "safetensors" if safetensors_files else "gguf"

        # NOTE: unlike scan_directory, no min-size check is applied here.
        return ScannedModel(
            path=str(path.resolve()),
            format=format_type,
            size_bytes=total_size,
            file_count=len(model_files),
            files=model_files,
        )

    def _calculate_total_size(self, directory: Path, filenames: List[str]) -> int:
        """
        Calculate total size of specified files in directory

        Args:
            directory: Directory containing the files
            filenames: List of filenames to sum

        Returns:
            Total size in bytes (inaccessible files are silently skipped)
        """
        total = 0
        for filename in filenames:
            file_path = directory / filename
            if file_path.exists():
                try:
                    total += file_path.stat().st_size
                except OSError:
                    # File might be inaccessible, skip it
                    pass
        return total
|
||||
|
||||
|
||||
# Convenience functions
|
||||
|
||||
|
||||
def scan_directory(
    base_path: Path, min_size_gb: float = 10.0, exclude_paths: Optional[Set[str]] = None
) -> Tuple[List[ScannedModel], List[str]]:
    """
    Convenience wrapper: scan *base_path* with a throwaway ModelScanner.

    Args:
        base_path: Root directory to scan
        min_size_gb: Minimum folder size in GB
        exclude_paths: Set of paths to exclude

    Returns:
        Tuple of (models, warnings)
    """
    return ModelScanner(min_size_gb=min_size_gb).scan_directory(base_path, exclude_paths)
|
||||
|
||||
|
||||
def scan_single_path(path: Path) -> Optional[ScannedModel]:
    """
    Convenience wrapper: scan one directory with a default ModelScanner.

    Args:
        path: Path to scan

    Returns:
        ScannedModel or None
    """
    return ModelScanner().scan_single_path(path)
|
||||
|
||||
|
||||
def format_size(size_bytes: int) -> str:
    """
    Render a byte count as a human-readable string.

    Args:
        size_bytes: Size in bytes

    Returns:
        Formatted string with one decimal place (e.g., "42.3 GB")
    """
    value = float(size_bytes)
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if value < 1024.0:
            return f"{value:.1f} {unit}"
        value /= 1024.0
    # Anything past TB is reported in petabytes.
    return f"{value:.1f} PB"
|
||||
|
||||
|
||||
# ===== Fast Scanning with Find Command and Tree-based Root Detection =====
|
||||
|
||||
|
||||
def find_files_fast(mount_point: str, pattern: str, max_depth: int = 6, timeout: int = 30) -> List[str]:
    """
    Use the system `find` command to quickly locate files.

    stderr is redirected to /dev/null inside the shell command so permission
    errors on unreadable directories are silently ignored; partial results
    are returned even when `find` exits non-zero.

    Args:
        mount_point: Starting directory
        pattern: File pattern (e.g., "config.json", "*.gguf")
        max_depth: Maximum directory depth (default: 6)
        timeout: Command timeout in seconds

    Returns:
        List of absolute file paths (empty on timeout or missing shell)
    """
    command = f'find "{mount_point}" -maxdepth {max_depth} -name "{pattern}" -type f 2>/dev/null'
    try:
        proc = subprocess.run(
            command,
            capture_output=True,
            text=True,
            timeout=timeout,
            shell=True,
        )
    except (subprocess.TimeoutExpired, FileNotFoundError):
        return []

    if not proc.stdout:
        return []
    return [entry.strip() for entry in proc.stdout.strip().split("\n") if entry.strip()]
|
||||
|
||||
|
||||
def is_valid_model_directory(directory: Path, min_size_gb: float = 10.0) -> Tuple[bool, Optional[str]]:
    """
    Check if a directory is a valid model directory.

    A directory qualifies when it contains *.safetensors or *.gguf files
    whose combined size is at least min_size_gb. Safetensors wins when both
    formats are present.

    Args:
        directory: Path to check
        min_size_gb: Minimum size in GB

    Returns:
        (is_valid, model_type) where model_type is "safetensors", "gguf", or None
    """
    if not directory.exists() or not directory.is_dir():
        return False, None

    safetensors_files = list(directory.glob("*.safetensors"))
    gguf_files = list(directory.glob("*.gguf"))

    # Fix: the original computed has_config but its condition
    # `(has_config and safetensors_files) or safetensors_files` reduced to
    # just `safetensors_files`, so config.json never affected the outcome.
    # The dead check is removed; behavior is unchanged.
    if safetensors_files:
        model_type = "safetensors"
        model_files = safetensors_files
    elif gguf_files:
        model_type = "gguf"
        model_files = gguf_files
    else:
        return False, None

    # Check size - only count model files (fast!); unreadable files are skipped.
    total_size = 0
    for f in model_files:
        try:
            total_size += f.stat().st_size
        except OSError:
            pass

    if total_size / (1024**3) < min_size_gb:
        return False, None

    return True, model_type
|
||||
|
||||
|
||||
def scan_all_models_fast(mount_points: List[str], min_size_gb: float = 10.0, max_depth: int = 6) -> List[str]:
    """
    Fast scan for all model paths using the `find` command.

    Candidate directories are the parents of any config.json or *.gguf file;
    each candidate is then validated for format and minimum size.

    Args:
        mount_points: List of mount points to scan
        min_size_gb: Minimum model size in GB
        max_depth: Maximum search depth (default: 6)

    Returns:
        Sorted list of valid model directory paths (deduplicated)
    """
    found = set()

    for mount in mount_points:
        if not os.path.exists(mount):
            continue

        # Parents of config.json files, then parents of *.gguf files.
        candidates = [Path(p).parent for p in find_files_fast(mount, "config.json", max_depth=max_depth)]
        candidates += [Path(p).parent for p in find_files_fast(mount, "*.gguf", max_depth=max_depth)]

        for model_dir in candidates:
            ok, _ = is_valid_model_directory(model_dir, min_size_gb)
            if ok:
                found.add(str(model_dir.resolve()))

    return sorted(found)
|
||||
|
||||
|
||||
def get_root_subdirs() -> List[str]:
    """
    List subdirectories of / that are worth scanning for models.

    Well-known system paths (proc, sys, usr, ...) are filtered out.

    Returns:
        Sorted list of directory paths to scan
    """
    # System paths to exclude
    excluded = {
        "dev",
        "proc",
        "sys",
        "run",
        "boot",
        "tmp",
        "usr",
        "lib",
        "lib64",
        "bin",
        "sbin",
        "etc",
        "opt",
        "var",
        "snap",
    }

    results: List[str] = []
    try:
        for entry in os.scandir("/"):
            # Keep only non-excluded directories.
            if entry.is_dir() and entry.name not in excluded:
                results.append(entry.path)
    except PermissionError:
        pass  # keep whatever was collected before the error

    return sorted(results)
|
||||
|
||||
|
||||
def _group_large_files_by_dir(directory: str, pattern: str, size_filter: str) -> Dict[str, list]:
    """Run `find` for pattern files >= size_filter and group (name, size) pairs by parent dir."""
    cmd = f'find "{directory}" -name "{pattern}" -type f -size {size_filter} -printf "%p\\t%s\\n" 2>/dev/null'
    try:
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=120)
    except subprocess.TimeoutExpired:
        # Fix: the original let TimeoutExpired propagate, unlike
        # find_files_fast which swallows it; treat a timeout as "no results".
        return {}

    grouped = defaultdict(list)
    for line in result.stdout.strip().split("\n"):
        if not line:
            continue
        parts = line.split("\t")
        if len(parts) != 2:
            continue  # malformed line; skip defensively
        file_path, size_str = parts
        file_path_obj = Path(file_path)
        grouped[str(file_path_obj.parent)].append((file_path_obj.name, int(size_str)))
    return grouped


def scan_directory_for_models(directory: str, min_file_size_gb: float = 2.0) -> Dict[str, tuple]:
    """
    Scan a directory for models using the `find` command with a size filter.

    Uses `find -size +NG` (or +NM below 1 GB) so only large model files are
    considered. GGUF directories qualify on their own; safetensors
    directories additionally require a config.json alongside the shards.

    Args:
        directory: Directory to scan
        min_file_size_gb: Minimum individual file size in GB (default: 2.0)

    Returns:
        Dict mapping model_path -> (model_type, size_bytes, file_count, files)
    """
    model_info = {}

    # Convert GB to find's size syntax (e.g., 2 GB -> "+2G", 0.5 GB -> "+512M").
    if min_file_size_gb >= 1.0:
        size_filter = f"+{int(min_file_size_gb)}G"
    else:
        size_filter = f"+{int(min_file_size_gb * 1024)}M"

    # 1. GGUF: any directory holding large .gguf files is a model directory.
    for dir_path, files in _group_large_files_by_dir(directory, "*.gguf", size_filter).items():
        total_size = sum(size for _, size in files)
        model_info[dir_path] = ("gguf", total_size, len(files), [name for name, _ in files])

    # 2. Safetensors: additionally require config.json next to the shards.
    for dir_path, files in _group_large_files_by_dir(directory, "*.safetensors", size_filter).items():
        if os.path.exists(os.path.join(dir_path, "config.json")):
            total_size = sum(size for _, size in files)
            model_info[dir_path] = ("safetensors", total_size, len(files), [name for name, _ in files])

    return model_info
|
||||
|
||||
|
||||
def scan_all_models_with_info(
    mount_points: Optional[List[str]] = None, min_size_gb: float = 10.0, max_depth: int = 6
) -> Dict[str, tuple]:
    """
    Fast scan with parallel directory scanning.

    Strategy:
    1. Use the provided directories, or auto-detect root subdirectories
    2. Scan each directory in parallel (up to 8 worker threads)
    3. Use `find -size +2G` to find large model files (>=2GB)

    Args:
        mount_points: Specific directories to scan, or None to auto-detect from / subdirs
        min_size_gb: Not used anymore (kept for API compatibility)
        max_depth: Not used anymore (kept for API compatibility)

    Returns:
        Dict mapping model_path -> (model_type, size_bytes, file_count, files)
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    # Fall back to root subdirectories (system paths excluded) when no
    # explicit mount points were given.
    scan_dirs = get_root_subdirs() if mount_points is None else mount_points
    if not scan_dirs:
        return {}

    merged: Dict[str, tuple] = {}

    # One worker per directory, capped at 8; 2 GB file-size threshold.
    with ThreadPoolExecutor(max_workers=min(len(scan_dirs), 8)) as pool:
        pending = {pool.submit(scan_directory_for_models, d, 2.0): d for d in scan_dirs}
        for done in as_completed(pending):
            try:
                merged.update(done.result())
            except Exception:
                continue  # best-effort: skip directories that failed to scan

    return merged
|
||||
|
||||
|
||||
def find_model_roots_from_paths(model_paths: List[str]) -> Tuple[List[str], Dict[str, int]]:
    """
    Find optimal root paths from model paths using tree-based algorithm

    Algorithm:
    1. Build path tree with all intermediate paths
    2. DFS to calculate f(x) = subtree sum (number of models in subtree)
    3. Find roots where f(parent) = f(x) > max(f(children))

    Args:
        model_paths: List of model directory paths

    Returns:
        (root_paths, subtree_sizes) where:
        - root_paths: List of inferred root directories
        - subtree_sizes: Dict mapping each root to number of models
    """
    if not model_paths:
        return [], {}

    # 1. Build path set (including all intermediate paths)
    # Every prefix of each model path becomes a node in the implicit tree.
    all_paths = set()
    model_set = set(model_paths)

    for model_path in model_paths:
        path = Path(model_path)
        for i in range(1, len(path.parts) + 1):
            all_paths.add(str(Path(*path.parts[:i])))

    # 2. Build parent-child relationships
    # Only link a child to a parent that is itself a known node.
    children_map = defaultdict(list)
    for path in all_paths:
        path_obj = Path(path)
        if len(path_obj.parts) > 1:
            parent = str(path_obj.parent)
            if parent in all_paths:
                children_map[parent].append(path)

    # 3. DFS to calculate f(x) and max_child_f(x)
    f = {}  # path -> subtree sum (number of model paths in the subtree)
    max_child_f = {}  # path -> max(f(children)), 0 for leaves
    visited = set()  # memoization guard: each node is evaluated once

    def dfs(path: str) -> int:
        # Return the memoized subtree sum if already computed.
        if path in visited:
            return f[path]
        visited.add(path)

        # Current node weight (1 if it's a model path, 0 otherwise)
        weight = 1 if path in model_set else 0

        # Recursively calculate children
        children = children_map.get(path, [])
        if not children:
            # Leaf node
            f[path] = weight
            max_child_f[path] = 0
            return weight

        # Calculate f values for all children
        children_f_values = [dfs(child) for child in children]

        # Calculate f(x) and max_child_f(x)
        f[path] = weight + sum(children_f_values)
        max_child_f[path] = max(children_f_values) if children_f_values else 0

        return f[path]

    # Find top-level nodes (no parent in all_paths)
    top_nodes = []
    for path in all_paths:
        parent = str(Path(path).parent)
        # parent == path handles the filesystem root, whose parent is itself.
        if parent not in all_paths or parent == path:
            top_nodes.append(path)

    # Execute DFS from all top nodes
    for top in top_nodes:
        dfs(top)

    # 4. Find root nodes: f(parent) = f(x) >= max(f(children))
    # Note: Use >= instead of > to handle the case where a directory contains only one model
    candidate_roots = []
    for path in all_paths:
        # Skip model paths themselves (leaf nodes in model tree)
        if path in model_set:
            continue

        parent = str(Path(path).parent)

        # Check condition: f(parent) = f(x) and f(x) >= max(f(children))
        # i.e. this node accounts for all of its parent's models, and no
        # single child accounts for all of this node's models.
        if parent in f and f.get(parent, 0) == f.get(path, 0):
            if f.get(path, 0) >= max_child_f.get(path, 0) and f.get(path, 0) > 0:
                candidate_roots.append(path)

    # 5. Remove redundant roots (prefer deeper paths)
    # If a root is an ancestor of another root with the same f value, remove it.
    # Sorting deepest-first means descendants are selected before their ancestors.
    roots = []
    candidate_roots_sorted = sorted(candidate_roots, key=lambda p: -len(Path(p).parts))

    for root in candidate_roots_sorted:
        # Check if this root is a parent of any already selected root
        is_redundant = False
        for selected in roots:
            if selected.startswith(root + "/"):
                # selected is a child of root
                # Only keep root if it has more models (shouldn't happen by algorithm)
                if f.get(root, 0) == f.get(selected, 0):
                    is_redundant = True
                    break

        if not is_redundant:
            # Also filter out very shallow paths (< 3 levels)
            if len(Path(root).parts) >= 3:
                roots.append(root)

    # Build subtree sizes for roots
    subtree_sizes = {root: f.get(root, 0) for root in roots}

    return sorted(roots), subtree_sizes
|
||||
|
||||
|
||||
@dataclass
class ModelRootInfo:
    """Information about a detected model root path"""

    # Absolute path of the root directory that groups the models.
    path: str
    # Number of models found under this root.
    model_count: int
    # The scanned models located under this root.
    models: List[ScannedModel]
|
||||
|
||||
|
||||
def discover_models(
    mount_points: Optional[List[str]] = None, min_size_gb: float = 10.0, max_depth: int = 6
) -> List[ScannedModel]:
    """
    Discover all model directories on the system.

    Performs a single fast scan (via ``scan_all_models_with_info``) and wraps
    the raw results in :class:`ScannedModel` objects.

    Args:
        mount_points: List of mount points to scan (None = auto-detect)
        min_size_gb: Minimum model size in GB (default: 10.0)
        max_depth: Maximum search depth (default: 6)

    Returns:
        List of ScannedModel sorted by path
    """
    # Auto-detect mount points if the caller did not provide any.
    scan_targets = mount_points if mount_points is not None else _get_mount_points()

    # Single cached scan — all size/count/file info comes back in one pass.
    model_info = scan_all_models_with_info(scan_targets, min_size_gb, max_depth)
    if not model_info:
        return []

    # Wrap the raw tuples in ScannedModel objects, sorted by path.
    discovered = [
        ScannedModel(path=path, format=fmt, size_bytes=size, file_count=count, files=files)
        for path, (fmt, size, count, files) in model_info.items()
    ]
    return sorted(discovered, key=lambda m: m.path)
|
||||
|
||||
|
||||
def _get_mount_points() -> List[str]:
    """
    Get all valid mount points from /proc/mounts, filtering out system paths

    Returns:
        List of mount point paths suitable for model storage
        (excludes root "/" to avoid scanning entire filesystem)
    """
    mount_points = set()

    # System paths to exclude (unlikely to contain model files)
    excluded_paths = [
        "/snap/",
        "/proc/",
        "/sys/",
        "/run/",
        "/boot",
        "/dev/",
        "/usr",
        "/lib",
        "/lib64",
        "/bin",
        "/sbin",
        "/etc",
        "/opt",
        "/var",
        "/tmp",
    ]

    # Pseudo filesystems that never hold model weights.
    # Hoisted out of the per-line loop (was rebuilt on every iteration).
    pseudo_fs = {
        "proc",
        "sysfs",
        "devpts",
        "tmpfs",
        "devtmpfs",
        "cgroup",
        "cgroup2",
        "pstore",
        "bpf",
        "tracefs",
        "debugfs",
        "hugetlbfs",
        "mqueue",
        "configfs",
        "securityfs",
        "fuse.gvfsd-fuse",
        "fusectl",
        "squashfs",
        "overlay",  # snap packages
    }

    try:
        with open("/proc/mounts", "r") as f:
            for line in f:
                parts = line.split()
                if len(parts) < 3:
                    continue

                # /proc/mounts format: device mount_point fs_type ...
                mount_point, fs_type = parts[1], parts[2]

                # Filter out pseudo filesystems
                if fs_type in pseudo_fs:
                    continue

                # Skip root directory (too large to scan)
                if mount_point == "/":
                    continue

                # Filter out system paths
                if any(mount_point.startswith(x) for x in excluded_paths):
                    continue

                # Only include if it exists and is readable
                if os.path.exists(mount_point) and os.access(mount_point, os.R_OK):
                    mount_points.add(mount_point)

        # If no mount points found, add common data directories
        if not mount_points:
            # Add /home if it exists and is not already a separate mount point
            common_paths = ["/home", "/data", "/mnt"]
            for path in common_paths:
                if os.path.exists(path) and os.access(path, os.R_OK):
                    mount_points.add(path)

    except (FileNotFoundError, PermissionError):
        # Fallback to common paths (e.g. non-Linux or restricted environments)
        mount_points = {"/home", "/mnt", "/data"}

    return sorted(mount_points)
|
||||
254
kt-kernel/python/cli/utils/model_table_builder.py
Normal file
254
kt-kernel/python/cli/utils/model_table_builder.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Shared model table builders for consistent UI across commands.
|
||||
|
||||
Provides reusable table construction functions for displaying models
|
||||
in kt model list, kt quant, kt run, etc.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from pathlib import Path
|
||||
from rich.table import Table
|
||||
from rich.console import Console
|
||||
import json
|
||||
|
||||
|
||||
def format_model_size(model_path: Path, format_type: str) -> str:
    """Calculate and format model size."""
    from kt_kernel.cli.utils.model_scanner import format_size

    # Map the declared model format to the glob pattern of its weight files.
    patterns = {"safetensors": "*.safetensors", "gguf": "*.gguf"}
    pattern = patterns.get(format_type)
    if pattern is None:
        return "[dim]-[/dim]"

    try:
        total = sum(f.stat().st_size for f in model_path.glob(pattern) if f.exists())
        return format_size(total)
    except Exception:
        # Unreadable directory or racing deletions — show a placeholder.
        return "[dim]-[/dim]"
|
||||
|
||||
|
||||
def format_repo_info(model) -> str:
    """Format repository information."""
    # No repository recorded — render a dim placeholder.
    if not model.repo_id:
        return "[dim]-[/dim]"
    prefix = "hf" if model.repo_type == "huggingface" else "ms"
    return f"{prefix}:{model.repo_id}"
|
||||
|
||||
|
||||
def format_sha256_status(model, status_map: dict) -> str:
    """Format SHA256 verification status."""
    # Missing status defaults to "not_checked"; unknown keys render as "?".
    status_key = model.sha256_status or "not_checked"
    return status_map.get(status_key, "[dim]?[/dim]")
|
||||
|
||||
|
||||
def build_moe_gpu_table(
    models: List, status_map: dict, show_index: bool = True, start_index: int = 1
) -> Tuple[Table, List]:
    """
    Build MoE GPU models table.

    Args:
        models: List of MoE GPU model objects
        status_map: SHA256_STATUS_MAP for formatting status
        show_index: Whether to show # column for selection (default: True)
        start_index: Starting index number

    Returns:
        Tuple of (Table object, list of models in display order)
    """
    table = Table(show_header=True, header_style="bold", show_lines=False)

    if show_index:
        table.add_column("#", justify="right", style="cyan", no_wrap=True)

    table.add_column("Name", style="cyan", no_wrap=True)
    table.add_column("Path", style="dim", overflow="fold")
    table.add_column("Total", justify="right")
    table.add_column("Exps", justify="center", style="yellow")
    table.add_column("Act", justify="center", style="green")
    table.add_column("Repository", style="dim", overflow="fold")
    table.add_column("SHA256", justify="center")

    displayed_models = []

    for idx, model in enumerate(models, start_index):
        displayed_models.append(model)

        # Expert counts fall back to a dim dash when the metadata is absent.
        experts = str(model.moe_num_experts) if model.moe_num_experts else "[dim]-[/dim]"
        active = str(model.moe_num_experts_per_tok) if model.moe_num_experts_per_tok else "[dim]-[/dim]"

        cells = [str(idx)] if show_index else []
        cells.extend(
            [
                model.name,
                model.path,
                format_model_size(Path(model.path), "safetensors"),
                experts,
                active,
                format_repo_info(model),
                format_sha256_status(model, status_map),
            ]
        )
        table.add_row(*cells)

    return table, displayed_models
|
||||
|
||||
|
||||
def build_amx_table(
    models: List,
    status_map: dict = None,  # Kept for API compatibility but not used
    show_index: bool = True,
    start_index: int = 1,
    show_linked_gpus: bool = False,
    gpu_models: Optional[List] = None,
) -> Tuple[Table, List]:
    """
    Build AMX models table.

    Note: AMX models are locally quantized, so no SHA256 verification column.

    Args:
        models: List of AMX model objects
        status_map: (Unused - kept for API compatibility)
        show_index: Whether to show # column for selection (default: True)
        start_index: Starting index number
        show_linked_gpus: Whether to show sub-rows for linked GPU models
        gpu_models: List of GPU models (required if show_linked_gpus=True)

    Returns:
        Tuple of (Table object, list of models in display order)
    """
    table = Table(show_header=True, header_style="bold", show_lines=False)

    if show_index:
        table.add_column("#", justify="right", style="cyan", no_wrap=True)

    table.add_column("Name", style="cyan", no_wrap=True)
    table.add_column("Path", style="dim", overflow="fold")
    table.add_column("Total", justify="right")
    table.add_column("Method", justify="center", style="yellow")
    table.add_column("NUMA", justify="center", style="green")
    table.add_column("Source", style="dim", overflow="fold")

    # Build reverse map if needed: AMX model id -> names of GPU models that use it.
    amx_used_by_gpu = {}
    if show_linked_gpus and gpu_models:
        for model in models:
            if model.gpu_model_ids:
                gpu_names = []
                for gpu_id in model.gpu_model_ids:
                    # Linear scan of gpu_models to resolve each id to a display name.
                    for gpu_model in gpu_models:
                        if gpu_model.id == gpu_id:
                            gpu_names.append(gpu_model.name)
                            break
                if gpu_names:
                    amx_used_by_gpu[model.id] = gpu_names

    displayed_models = []

    for i, model in enumerate(models, start_index):
        displayed_models.append(model)

        # Calculate size
        size_str = format_model_size(Path(model.path), "safetensors")

        # Read metadata from config.json or UserModel fields
        method_from_config = None
        numa_from_config = None
        try:
            config_path = Path(model.path) / "config.json"
            if config_path.exists():
                with open(config_path, "r", encoding="utf-8") as f:
                    config = json.load(f)
                # Only trust the metadata when the conversion is marked complete.
                amx_quant = config.get("amx_quantization", {})
                if amx_quant.get("converted"):
                    method_from_config = amx_quant.get("method")
                    numa_from_config = amx_quant.get("numa_count")
        except Exception:
            # Missing/corrupt config.json — fall back to UserModel fields or "?".
            pass

        # Priority: UserModel fields > config.json > ?
        method_display = (
            model.amx_quant_method.upper()
            if model.amx_quant_method
            else method_from_config.upper() if method_from_config else "[dim]?[/dim]"
        )
        numa_display = (
            str(model.amx_numa_nodes)
            if model.amx_numa_nodes
            else str(numa_from_config) if numa_from_config else "[dim]?[/dim]"
        )
        source_display = model.amx_source_model or "[dim]-[/dim]"

        row = []
        if show_index:
            row.append(str(i))

        row.extend([model.name, model.path, size_str, method_display, numa_display, source_display])

        table.add_row(*row)

        # Add sub-row showing linked GPUs (indented, dimmed, no index).
        if show_linked_gpus and model.id in amx_used_by_gpu:
            gpu_list = amx_used_by_gpu[model.id]
            gpu_names_str = ", ".join([f"[dim]{name}[/dim]" for name in gpu_list])
            sub_row = []
            if show_index:
                sub_row.append("")
            sub_row.extend([f"  [dim]↳ GPU: {gpu_names_str}[/dim]", "", "", "", "", ""])
            table.add_row(*sub_row, style="dim")

    return table, displayed_models
|
||||
|
||||
|
||||
def build_gguf_table(
    models: List, status_map: dict, show_index: bool = True, start_index: int = 1
) -> Tuple[Table, List]:
    """
    Build GGUF models table.

    Args:
        models: List of GGUF model objects
        status_map: SHA256_STATUS_MAP for formatting status
        show_index: Whether to show # column for selection (default: True)
        start_index: Starting index number

    Returns:
        Tuple of (Table object, list of models in display order)
    """
    table = Table(show_header=True, header_style="bold", show_lines=False)

    if show_index:
        table.add_column("#", justify="right", style="cyan", no_wrap=True)

    table.add_column("Name", style="cyan", no_wrap=True)
    table.add_column("Path", style="dim", overflow="fold")
    table.add_column("Total", justify="right")
    table.add_column("Repository", style="dim", overflow="fold")
    table.add_column("SHA256", justify="center")

    displayed_models = []

    for idx, model in enumerate(models, start_index):
        displayed_models.append(model)

        cells = [str(idx)] if show_index else []
        cells.extend(
            [
                model.name,
                model.path,
                format_model_size(Path(model.path), "gguf"),
                format_repo_info(model),
                format_sha256_status(model, status_map),
            ]
        )
        table.add_row(*cells)

    return table, displayed_models
|
||||
918
kt-kernel/python/cli/utils/model_verifier.py
Normal file
918
kt-kernel/python/cli/utils/model_verifier.py
Normal file
@@ -0,0 +1,918 @@
|
||||
"""
|
||||
Model Verifier
|
||||
|
||||
SHA256 verification for model integrity
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import requests
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Literal, Tuple
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
|
||||
|
||||
def _compute_file_sha256(file_path: Path) -> Tuple[str, str, float]:
|
||||
"""
|
||||
Compute SHA256 for a single file (worker function for multiprocessing).
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
Tuple of (filename, sha256_hash, file_size_mb)
|
||||
"""
|
||||
sha256_hash = hashlib.sha256()
|
||||
file_size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
# Read file in chunks to handle large files
|
||||
with open(file_path, "rb") as f:
|
||||
for byte_block in iter(lambda: f.read(8192 * 1024), b""): # 8MB chunks
|
||||
sha256_hash.update(byte_block)
|
||||
|
||||
return file_path.name, sha256_hash.hexdigest(), file_size_mb
|
||||
|
||||
|
||||
def check_huggingface_connectivity(timeout: int = 5) -> Tuple[bool, str]:
    """
    Check if HuggingFace is accessible.

    Args:
        timeout: Connection timeout in seconds

    Returns:
        Tuple of (is_accessible, message)
    """
    test_url = "https://huggingface.co"

    try:
        response = requests.head(test_url, timeout=timeout, allow_redirects=True)
    except requests.exceptions.Timeout:
        return False, f"Connection to {test_url} timed out"
    except requests.exceptions.ConnectionError:
        return False, f"Cannot connect to {test_url}"
    except requests.exceptions.RequestException as e:
        return False, f"Connection error: {str(e)}"

    if response.status_code < 500:  # 2xx, 3xx, 4xx are all considered "accessible"
        return True, "HuggingFace is accessible"

    # Fix: a 5xx response previously fell through to a generic
    # "Unknown connection error"; report the actual server status instead.
    return False, f"{test_url} returned server error {response.status_code}"
|
||||
|
||||
|
||||
def verify_model_integrity(
    repo_type: Literal["huggingface", "modelscope"],
    repo_id: str,
    local_dir: Path,
    progress_callback=None,
) -> Dict[str, Any]:
    """
    Verify local model integrity against remote repository SHA256 hashes.

    Verifies all important files:
    - *.safetensors (weights)
    - *.json (config files)
    - *.py (custom model code)

    Args:
        repo_type: Type of repository ("huggingface" or "modelscope")
        repo_id: Repository ID (e.g., "deepseek-ai/DeepSeek-V3")
        local_dir: Local directory containing model files
        progress_callback: Optional callback function(message: str) for progress updates

    Returns:
        Dictionary with verification results:
        {
            "status": "passed" | "failed" | "error",
            "files_checked": int,
            "files_passed": int,
            "files_failed": [list of filenames],
            "error_message": str (optional)
        }
    """

    def report_progress(msg: str):
        """Helper to report progress"""
        if progress_callback:
            progress_callback(msg)

    try:
        # Convert repo_type to platform format
        platform = "hf" if repo_type == "huggingface" else "ms"

        # 1. Fetch official SHA256 hashes from remote
        report_progress("Fetching official SHA256 hashes from remote repository...")
        official_hashes = fetch_model_sha256(repo_id, platform)
        report_progress(f"✓ Fetched {len(official_hashes)} file hashes from remote")

        if not official_hashes:
            return {
                "status": "error",
                "files_checked": 0,
                "files_passed": 0,
                "files_failed": [],
                "error_message": f"No verifiable files found in remote repository: {repo_id}",
            }

        # 2. Calculate local SHA256 hashes with progress
        report_progress("Calculating SHA256 for local files...")

        # Get all local files matching the patterns
        local_files = []
        for pattern in ["*.safetensors", "*.json", "*.py"]:
            local_files.extend([f for f in local_dir.glob(pattern) if f.is_file()])

        if not local_files:
            return {
                "status": "error",
                "files_checked": 0,
                "files_passed": 0,
                "files_failed": [],
                "error_message": f"No verifiable files found in local directory: {local_dir}",
            }

        # Calculate hashes for all files
        local_hashes = calculate_local_sha256(
            local_dir,
            file_pattern="*.safetensors",  # Unused when files_list is provided
            progress_callback=report_progress,
            files_list=local_files,
        )
        report_progress(f"✓ Calculated {len(local_hashes)} local file hashes")

        # 3. Compare hashes with progress
        report_progress(f"Comparing {len(official_hashes)} files...")
        files_failed = []
        files_missing = []
        files_passed = 0

        for idx, (filename, official_hash) in enumerate(official_hashes.items(), 1):
            # Handle potential path separators in filename
            file_basename = Path(filename).name

            # Try to find the file in local hashes
            local_hash = None
            for local_file, local_hash_value in local_hashes.items():
                if Path(local_file).name == file_basename:
                    local_hash = local_hash_value
                    break

            if local_hash is None:
                files_missing.append(filename)
                report_progress(f" [{idx}/{len(official_hashes)}] ✗ {file_basename} - MISSING")
            elif local_hash.lower() != official_hash.lower():
                # Fix: record the actual mismatching filename (was a literal
                # "(unknown)" placeholder, which made failure reports useless).
                files_failed.append(f"{file_basename} (hash mismatch)")
                report_progress(f" [{idx}/{len(official_hashes)}] ✗ {file_basename} - HASH MISMATCH")
            else:
                files_passed += 1
                report_progress(f" [{idx}/{len(official_hashes)}] ✓ {file_basename}")

        # 4. Return results
        total_checked = len(official_hashes)

        if files_failed or files_missing:
            all_failed = files_failed + [f"{f} (missing)" for f in files_missing]
            return {
                "status": "failed",
                "files_checked": total_checked,
                "files_passed": files_passed,
                "files_failed": all_failed,
                "error_message": f"{len(all_failed)} file(s) failed verification",
            }
        else:
            return {
                "status": "passed",
                "files_checked": total_checked,
                "files_passed": files_passed,
                "files_failed": [],
            }

    except ImportError as e:
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": f"Missing required package: {str(e)}. Install with: pip install huggingface-hub modelscope",
            "is_network_error": False,
        }
    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
        requests.exceptions.RequestException,
    ) as e:
        # Network-related errors - suggest mirror
        error_msg = f"Network error: {str(e)}"
        if repo_type == "huggingface":
            error_msg += "\n\nTry using HuggingFace mirror:\n  export HF_ENDPOINT=https://hf-mirror.com"
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": error_msg,
            "is_network_error": True,
        }
    except Exception as e:
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": f"Verification failed: {str(e)}",
            "is_network_error": False,
        }
|
||||
|
||||
|
||||
def calculate_local_sha256(
    local_dir: Path, file_pattern: str = "*.safetensors", progress_callback=None, files_list: list[Path] | None = None
) -> Dict[str, str]:
    """
    Calculate SHA256 hashes for files in a directory using parallel processing.

    Args:
        local_dir: Directory to scan
        file_pattern: Glob pattern for files to hash (ignored if files_list is provided)
        progress_callback: Optional callback function(message: str) for progress updates
        files_list: Optional pre-filtered list of files to hash (overrides file_pattern)

    Returns:
        Dictionary mapping filename to SHA256 hash
    """
    result: Dict[str, str] = {}

    if not local_dir.exists():
        return result

    # Resolve the work list up front so the total count can be reported.
    if files_list is not None:
        files_to_hash = files_list
    else:
        files_to_hash = [f for f in local_dir.glob(file_pattern) if f.is_file()]
    total_files = len(files_to_hash)

    if total_files == 0:
        return result

    # Use min(16, total_files) workers to avoid over-spawning processes
    max_workers = min(16, total_files)

    if progress_callback:
        progress_callback(f" Using {max_workers} parallel workers for SHA256 calculation")

    # SHA256 is CPU-intensive, so use processes (not threads) for real parallelism.
    completed_count = 0
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_file = {executor.submit(_compute_file_sha256, file_path): file_path for file_path in files_to_hash}

        # Process results as they complete
        for future in as_completed(future_to_file):
            completed_count += 1
            try:
                filename, sha256_hash, file_size_mb = future.result()
                result[filename] = sha256_hash

                if progress_callback:
                    # Fix: show the actual filename that completed (was a
                    # literal "(unknown)" placeholder).
                    progress_callback(f" [{completed_count}/{total_files}] ✓ {filename} ({file_size_mb:.1f} MB)")

            except Exception as e:
                file_path = future_to_file[future]
                if progress_callback:
                    progress_callback(f" [{completed_count}/{total_files}] ✗ {file_path.name} - Error: {str(e)}")

    return result
|
||||
|
||||
|
||||
def fetch_model_sha256(
|
||||
repo_id: str,
|
||||
platform: Literal["hf", "ms"],
|
||||
revision: str | None = None,
|
||||
use_mirror: bool = False,
|
||||
timeout: int | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
获取模型仓库中所有重要文件的 sha256 哈希值。
|
||||
|
||||
包括:
|
||||
- *.safetensors (权重文件)
|
||||
- *.json (配置文件:config.json, tokenizer_config.json 等)
|
||||
- *.py (自定义模型代码:modeling.py, configuration.py 等)
|
||||
|
||||
Args:
|
||||
repo_id: 仓库 ID,例如 "Qwen/Qwen3-30B-A3B"
|
||||
platform: 平台,"hf" (HuggingFace) 或 "ms" (ModelScope)
|
||||
revision: 版本/分支,默认 HuggingFace 为 "main",ModelScope 为 "master"
|
||||
use_mirror: 是否使用镜像(仅对 HuggingFace 有效)
|
||||
timeout: 网络请求超时时间(秒),None 表示不设置超时
|
||||
|
||||
Returns:
|
||||
dict: 文件名到 sha256 的映射,例如 {"model-00001-of-00016.safetensors": "abc123...", "config.json": "def456..."}
|
||||
"""
|
||||
if platform == "hf":
|
||||
# 先尝试直连,失败后自动使用镜像
|
||||
try:
|
||||
if use_mirror:
|
||||
return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=True, timeout=timeout)
|
||||
else:
|
||||
return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=False, timeout=timeout)
|
||||
except Exception as e:
|
||||
# 如果不是镜像模式且失败了,自动重试使用镜像
|
||||
if not use_mirror:
|
||||
return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=True, timeout=timeout)
|
||||
else:
|
||||
raise e
|
||||
elif platform == "ms":
|
||||
return _fetch_from_modelscope(repo_id, revision or "master", timeout=timeout)
|
||||
else:
|
||||
raise ValueError(f"不支持的平台: {platform},请使用 'hf' 或 'ms'")
|
||||
|
||||
|
||||
def _fetch_from_huggingface(
    repo_id: str, revision: str, use_mirror: bool = False, timeout: int | None = None
) -> dict[str, str]:
    """Fetch SHA256 hashes for all important files from HuggingFace.

    Args:
        repo_id: Repository ID
        revision: Revision/branch
        use_mirror: Whether to use the mirror (hf-mirror.com)
        timeout: Network request timeout in seconds; None means no timeout
    """
    import os
    import socket

    # If the mirror is requested, point the hub client at it via HF_ENDPOINT —
    # but only when the user has not already configured an endpoint themselves.
    original_endpoint = os.environ.get("HF_ENDPOINT")
    if use_mirror and not original_endpoint:
        os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

    # Set socket timeout if specified
    original_timeout = socket.getdefaulttimeout()
    if timeout is not None:
        socket.setdefaulttimeout(timeout)

    # NOTE: imported after HF_ENDPOINT is set so huggingface_hub picks it up.
    from huggingface_hub import HfApi, list_repo_files

    try:
        api = HfApi()
        all_files = list_repo_files(repo_id=repo_id, revision=revision)

        # Filter important files: *.safetensors, *.json, *.py
        important_files = [f for f in all_files if f.endswith((".safetensors", ".json", ".py"))]

        if not important_files:
            return {}

        paths_info = api.get_paths_info(
            repo_id=repo_id,
            paths=important_files,
            revision=revision,
        )

        result = {}
        for file_info in paths_info:
            # LFS-tracked files expose a real SHA256; small non-LFS files only
            # have a git blob id, which is used as a fallback identifier.
            if hasattr(file_info, "lfs") and file_info.lfs is not None:
                sha256 = file_info.lfs.sha256
            else:
                sha256 = getattr(file_info, "blob_id", None)
            result[file_info.path] = sha256

        return result
    finally:
        # Restore the original socket timeout
        socket.setdefaulttimeout(original_timeout)

        # Restore the original HF_ENDPOINT environment variable
        if use_mirror and not original_endpoint:
            os.environ.pop("HF_ENDPOINT", None)
        elif original_endpoint:
            os.environ["HF_ENDPOINT"] = original_endpoint
|
||||
|
||||
|
||||
def _fetch_from_modelscope(repo_id: str, revision: str, timeout: int | None = None) -> dict[str, str]:
    """Fetch SHA256 hashes for all important files from ModelScope.

    Args:
        repo_id: Repository ID
        revision: Revision/branch
        timeout: Network request timeout in seconds; None means no timeout
    """
    import socket
    from modelscope.hub.api import HubApi

    # Set socket timeout if specified
    original_timeout = socket.getdefaulttimeout()
    if timeout is not None:
        socket.setdefaulttimeout(timeout)

    try:
        api = HubApi()
        files_info = api.get_model_files(model_id=repo_id, revision=revision)

        result = {}
        for file_info in files_info:
            # The API uses "Name" or "Path" depending on version; try both.
            filename = file_info.get("Name", file_info.get("Path", ""))
            # Filter important files: *.safetensors, *.json, *.py
            if filename.endswith((".safetensors", ".json", ".py")):
                # Key casing also varies across API versions ("Sha256"/"sha256").
                sha256 = file_info.get("Sha256", file_info.get("sha256", None))
                result[filename] = sha256

        return result
    finally:
        # Restore the original socket timeout
        socket.setdefaulttimeout(original_timeout)
|
||||
|
||||
|
||||
def verify_model_integrity_with_progress(
    repo_type: Literal["huggingface", "modelscope"],
    repo_id: str,
    local_dir: Path,
    progress_callback=None,
    verbose: bool = False,
    use_mirror: bool = False,
    files_to_verify: list[str] | None = None,
    timeout: int | None = None,
) -> Dict[str, Any]:
    """
    Verify model integrity with enhanced progress reporting for Rich Progress bars.

    This is a wrapper around verify_model_integrity() that provides more detailed
    progress information suitable for progress bar display.

    The progress_callback receives:
    - (message: str, total: int, current: int) for countable operations
    - (message: str) for status updates

    Args:
        repo_type: Repository type ("huggingface" or "modelscope")
        repo_id: Repository ID
        local_dir: Local directory path
        progress_callback: Optional callback for progress updates
        verbose: If True, output detailed SHA256 comparison for each file
        use_mirror: If True, use HuggingFace mirror (hf-mirror.com)
        files_to_verify: Optional list of specific files to verify (for re-verification)
        timeout: Network request timeout in seconds (None = no timeout)

    Returns:
        Dict with keys: "status" ("passed"/"failed"/"error"), "files_checked",
        "files_passed", "files_failed", plus "error_message" and
        "is_network_error" on errors. Entries in "files_failed" carry a
        " (hash mismatch)" or " (missing)" marker after the filename.
    """

    def report_progress(msg: str, total=None, current=None):
        """Forward a progress update to the optional caller callback."""
        if progress_callback:
            progress_callback(msg, total, current)

    def clean_basenames(files: list[str]) -> set:
        """Strip status markers and directory parts, returning basenames."""
        names = set()
        for f in files:
            clean_f = f.replace(" (missing)", "").replace(" (hash mismatch)", "").strip()
            # Ensure we only use the filename, not full path
            names.add(Path(clean_f).name)
        return names

    try:
        platform = "hf" if repo_type == "huggingface" else "ms"

        # 1. Fetch official SHA256 hashes
        if files_to_verify:
            report_progress(f"Fetching SHA256 hashes for {len(files_to_verify)} files...")
        elif use_mirror and platform == "hf":
            report_progress("Fetching official SHA256 hashes from mirror (hf-mirror.com)...")
        else:
            report_progress("Fetching official SHA256 hashes from remote repository...")

        official_hashes = fetch_model_sha256(repo_id, platform, use_mirror=use_mirror, timeout=timeout)

        # Filter to only requested files if specified
        if files_to_verify:
            clean_filenames = clean_basenames(files_to_verify)
            # Compare using basename since official_hashes keys might have paths
            official_hashes = {k: v for k, v in official_hashes.items() if Path(k).name in clean_filenames}

        report_progress(f"✓ Fetched {len(official_hashes)} file hashes from remote")

        if not official_hashes:
            return {
                "status": "error",
                "files_checked": 0,
                "files_passed": 0,
                "files_failed": [],
                "error_message": f"No safetensors files found in remote repository: {repo_id}",
            }

        # 2. Calculate local SHA256 hashes
        local_dir_path = Path(local_dir)

        # Only hash the files we need to verify
        if files_to_verify:
            clean_filenames = clean_basenames(files_to_verify)
            files_to_hash = [
                f for f in local_dir_path.glob("*.safetensors") if f.is_file() and f.name in clean_filenames
            ]
        else:
            files_to_hash = [f for f in local_dir_path.glob("*.safetensors") if f.is_file()]

        total_files = len(files_to_hash)

        if files_to_verify:
            report_progress(f"Calculating SHA256 for {total_files} repaired files...", total=total_files, current=0)
        else:
            report_progress("Calculating SHA256 for local files...", total=total_files, current=0)

        # Progress wrapper for hashing
        completed_count = [0]  # Use list for mutable closure

        def hash_progress_callback(msg: str):
            if "Using" in msg and "workers" in msg:
                report_progress(msg)
            elif "[" in msg and "/" in msg and "]" in msg:
                # Progress update like: [1/10] ✓ filename (123.4 MB)
                completed_count[0] += 1
                report_progress(msg, total=total_files, current=completed_count[0])

        # Pass the pre-filtered files_to_hash list
        local_hashes = calculate_local_sha256(
            local_dir_path,
            "*.safetensors",
            progress_callback=hash_progress_callback,
            files_list=files_to_hash if files_to_verify else None,
        )
        report_progress(f"✓ Calculated {len(local_hashes)} local file hashes")

        # 3. Compare hashes
        report_progress(f"Comparing {len(official_hashes)} files...", total=len(official_hashes), current=0)

        files_failed = []
        files_missing = []
        files_passed = 0

        for idx, (filename, official_hash) in enumerate(official_hashes.items(), 1):
            file_basename = Path(filename).name

            # Find matching local file by basename (keys may carry paths)
            local_hash = None
            for local_file, local_hash_value in local_hashes.items():
                if Path(local_file).name == file_basename:
                    local_hash = local_hash_value
                    break

            if local_hash is None:
                files_missing.append(filename)
                if verbose:
                    report_progress(
                        f"[{idx}/{len(official_hashes)}] ✗ {file_basename} (missing)\n Remote: {official_hash}\n Local: <missing>",
                        total=len(official_hashes),
                        current=idx,
                    )
                else:
                    report_progress(
                        f"[{idx}/{len(official_hashes)}] ✗ {file_basename} (missing)",
                        total=len(official_hashes),
                        current=idx,
                    )
            elif local_hash.lower() != official_hash.lower():
                # BUGFIX: record the actual filename (previously the literal
                # string "(unknown)") so the "(hash mismatch)" marker can be
                # parsed back into a filename on re-verification.
                files_failed.append(f"{filename} (hash mismatch)")
                if verbose:
                    report_progress(
                        f"[{idx}/{len(official_hashes)}] ✗ {file_basename} (hash mismatch)\n Remote: {official_hash}\n Local: {local_hash}",
                        total=len(official_hashes),
                        current=idx,
                    )
                else:
                    report_progress(
                        f"[{idx}/{len(official_hashes)}] ✗ {file_basename} (hash mismatch)",
                        total=len(official_hashes),
                        current=idx,
                    )
            else:
                files_passed += 1
                if verbose:
                    report_progress(
                        f"[{idx}/{len(official_hashes)}] ✓ {file_basename}\n Remote: {official_hash}\n Local: {local_hash}",
                        total=len(official_hashes),
                        current=idx,
                    )
                else:
                    report_progress(
                        f"[{idx}/{len(official_hashes)}] ✓ {file_basename}", total=len(official_hashes), current=idx
                    )

        # 4. Return results
        total_checked = len(official_hashes)

        if files_failed or files_missing:
            all_failed = files_failed + [f"{f} (missing)" for f in files_missing]
            return {
                "status": "failed",
                "files_checked": total_checked,
                "files_passed": files_passed,
                "files_failed": all_failed,
                "error_message": f"{len(all_failed)} file(s) failed verification",
            }
        else:
            return {
                "status": "passed",
                "files_checked": total_checked,
                "files_passed": files_passed,
                "files_failed": [],
            }

    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
        requests.exceptions.RequestException,
        TimeoutError,  # Socket timeout from socket.setdefaulttimeout()
        OSError,  # Network-related OS errors
    ) as e:
        error_msg = f"Network error: {str(e)}"
        if repo_type == "huggingface":
            error_msg += "\n\nTry using HuggingFace mirror:\n export HF_ENDPOINT=https://hf-mirror.com"
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": error_msg,
            "is_network_error": True,
        }
    except Exception as e:
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": f"Verification failed: {str(e)}",
            "is_network_error": False,
        }
|
||||
|
||||
|
||||
def pre_operation_verification(user_model, user_registry, operation_name: str = "operation") -> None:
    """Pre-operation verification of model integrity.

    Can be used before running or quantizing models to ensure integrity.

    Flow: skip when the registry already records a passed check; otherwise
    optionally configure the repository ID, fetch remote SHA256 hashes
    (falling back to the HF mirror and then ModelScope on timeout), hash
    local *.safetensors files, compare, persist the result, and on failure
    let the user choose between repair, abort, or continuing at their risk.

    Args:
        user_model: UserModel object to verify
        user_registry: UserModelRegistry instance
        operation_name: Name of the operation (e.g., "running", "quantizing")

    Raises:
        typer.Exit: When the user declines to continue after a failed or
            impossible verification, or opts to run the repair flow instead.
    """
    from rich.prompt import Prompt, Confirm
    from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, MofNCompleteColumn, TimeElapsedColumn
    from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
    from kt_kernel.cli.i18n import get_lang
    from kt_kernel.cli.utils.console import console, print_info, print_warning, print_error, print_success, print_step
    import typer

    lang = get_lang()

    # Fast path: a previous run already verified this model.
    if user_model.sha256_status == "passed":
        console.print()
        print_info("Model integrity already verified ✓")
        console.print()
        return

    # Model not verified yet
    console.print()
    console.print("[bold yellow]═══ Model Integrity Check ═══[/bold yellow]")
    console.print()

    # Check if repo_id exists
    if not user_model.repo_id:
        # No repo_id - ask user to provide one
        console.print("[yellow]No repository ID configured for this model.[/yellow]")
        console.print()
        console.print("To verify model integrity, we need the repository ID (e.g., 'deepseek-ai/DeepSeek-V3')")
        console.print()

        if not Confirm.ask("Would you like to configure repository ID now?", default=True):
            console.print()
            print_warning(f"Skipping verification. Model will be used for {operation_name} without integrity check.")
            console.print()
            return

        # Ask for repo type
        console.print()
        console.print("Repository type:")
        console.print(" [cyan][1][/cyan] HuggingFace")
        console.print(" [cyan][2][/cyan] ModelScope")
        console.print()

        repo_type_choice = Prompt.ask("Select repository type", choices=["1", "2"], default="1")
        repo_type = "huggingface" if repo_type_choice == "1" else "modelscope"

        # Ask for repo_id
        console.print()
        repo_id = Prompt.ask("Enter repository ID (e.g., deepseek-ai/DeepSeek-V3)")

        # Persist the new repository info and mirror it on the in-memory object.
        user_registry.update_model(user_model.name, {"repo_type": repo_type, "repo_id": repo_id})
        user_model.repo_type = repo_type
        user_model.repo_id = repo_id

        console.print()
        print_success(f"Repository configured: {repo_type}:{repo_id}")
        console.print()

    # Now ask if user wants to verify
    console.print("[dim]Model integrity verification is a one-time check that ensures your[/dim]")
    console.print("[dim]model weights are not corrupted. This helps prevent runtime errors.[/dim]")
    console.print()

    if not Confirm.ask(f"Would you like to verify model integrity before {operation_name}?", default=True):
        console.print()
        print_warning(f"Skipping verification. Model will be used for {operation_name} without integrity check.")
        console.print()
        return

    # Perform verification
    console.print()
    print_step("Verifying model integrity...")
    console.print()

    # Check connectivity first
    use_mirror = False
    if user_model.repo_type == "huggingface":
        with console.status("[dim]Checking HuggingFace connectivity...[/dim]"):
            is_accessible, message = check_huggingface_connectivity(timeout=5)

        if not is_accessible:
            print_warning("HuggingFace Connection Failed")
            console.print()
            console.print(f" {message}")
            console.print()
            console.print(" [yellow]Auto-switching to HuggingFace mirror:[/yellow] [cyan]hf-mirror.com[/cyan]")
            console.print()
            use_mirror = True

    # Fetch remote hashes with timeout
    def fetch_with_timeout(repo_type, repo_id, use_mirror, timeout):
        """Fetch hashes in a worker thread; returns (hashes, timed_out)."""
        executor = ThreadPoolExecutor(max_workers=1)
        try:
            platform = "hf" if repo_type == "huggingface" else "ms"
            future = executor.submit(fetch_model_sha256, repo_id, platform, use_mirror=use_mirror, timeout=timeout)
            hashes = future.result(timeout=timeout)
            executor.shutdown(wait=False)
            return (hashes, False)
        except Exception:
            # FutureTimeoutError and fetch errors are both reported as a
            # timeout so the caller can try the next fallback source.
            executor.shutdown(wait=False)
            return (None, True)

    # Try fetching hashes
    status = console.status("[dim]Fetching remote hashes...[/dim]")
    status.start()
    official_hashes, timed_out = fetch_with_timeout(user_model.repo_type, user_model.repo_id, use_mirror, 10)
    status.stop()

    # Handle timeout with fallback: direct HF -> HF mirror -> ModelScope
    if timed_out and user_model.repo_type == "huggingface" and not use_mirror:
        print_warning("HuggingFace Fetch Timeout (10s)")
        console.print()
        console.print(" [yellow]Trying HuggingFace mirror...[/yellow]")
        console.print()

        status = console.status("[dim]Fetching remote hashes from mirror...[/dim]")
        status.start()
        official_hashes, timed_out = fetch_with_timeout(user_model.repo_type, user_model.repo_id, True, 10)
        status.stop()

    if timed_out and user_model.repo_type == "huggingface":
        print_warning("HuggingFace Mirror Timeout (10s)")
        console.print()
        console.print(" [yellow]Fallback to ModelScope...[/yellow]")
        console.print()

        status = console.status("[dim]Fetching remote hashes from ModelScope...[/dim]")
        status.start()
        official_hashes, timed_out = fetch_with_timeout("modelscope", user_model.repo_id, False, 10)
        status.stop()

    if not official_hashes or timed_out:
        print_error("Failed to fetch remote hashes (network timeout)")
        console.print()
        console.print(" [yellow]Unable to verify model integrity due to network issues.[/yellow]")
        console.print()

        if not Confirm.ask(f"Continue {operation_name} without verification?", default=False):
            raise typer.Exit(0)

        console.print()
        return

    console.print(f" [green]✓ Fetched {len(official_hashes)} file hashes[/green]")
    console.print()

    # Calculate local hashes and compare
    local_dir = Path(user_model.path)
    files_to_hash = [f for f in local_dir.glob("*.safetensors") if f.is_file()]

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        console=console,
    ) as progress:
        # Calculate local hashes
        task = progress.add_task("[yellow]Calculating local SHA256...", total=len(files_to_hash))

        def hash_callback(msg):
            # Per-file completion messages look like "[i/n] ✓ name (size)".
            if "[" in msg and "/" in msg and "]" in msg and "✓" in msg:
                progress.advance(task)

        local_hashes = calculate_local_sha256(local_dir, "*.safetensors", progress_callback=hash_callback)
        progress.remove_task(task)

        console.print(f" [green]✓ Calculated {len(local_hashes)} local hashes[/green]")
        console.print()

        # Compare hashes
        task = progress.add_task("[blue]Comparing hashes...", total=len(official_hashes))

        files_failed = []
        files_missing = []
        files_passed = 0

        for filename, official_hash in official_hashes.items():
            file_basename = Path(filename).name
            local_hash = None

            # Match by basename since remote keys may contain directory parts.
            for local_file, local_hash_value in local_hashes.items():
                if Path(local_file).name == file_basename:
                    local_hash = local_hash_value
                    break

            if local_hash is None:
                files_missing.append(filename)
            elif local_hash.lower() != official_hash.lower():
                # BUGFIX: record the actual filename (previously the literal
                # string "(unknown)") so the failure listing names the file.
                files_failed.append(f"{filename} (hash mismatch)")
            else:
                files_passed += 1

            progress.advance(task)

        progress.remove_task(task)

    console.print()

    # Check results
    if not files_failed and not files_missing:
        # Verification passed
        user_registry.update_model(user_model.name, {"sha256_status": "passed"})
        print_success("Model integrity verification PASSED ✓")
        console.print()
        console.print(f" All {files_passed} files verified successfully")
        console.print()
    else:
        # Verification failed
        user_registry.update_model(user_model.name, {"sha256_status": "failed"})
        print_error("Model integrity verification FAILED")
        console.print()
        console.print(f" ✓ Passed: [green]{files_passed}[/green]")
        console.print(f" ✗ Failed: [red]{len(files_failed) + len(files_missing)}[/red]")
        console.print()

        if files_missing:
            console.print(f" [red]Missing files ({len(files_missing)}):[/red]")
            for f in files_missing[:5]:
                console.print(f" - {Path(f).name}")
            if len(files_missing) > 5:
                console.print(f" ... and {len(files_missing) - 5} more")
            console.print()

        if files_failed:
            console.print(f" [red]Hash mismatch ({len(files_failed)}):[/red]")
            for f in files_failed[:5]:
                console.print(f" - {f}")
            if len(files_failed) > 5:
                console.print(f" ... and {len(files_failed) - 5} more")
            console.print()

        console.print("[bold red]⚠ WARNING: Model weights may be corrupted![/bold red]")
        console.print()
        console.print("This could cause runtime errors or incorrect inference results.")
        console.print()

        # Ask if user wants to repair
        if Confirm.ask("Would you like to repair (re-download) the corrupted files?", default=True):
            console.print()
            print_info("Please run: [cyan]kt model verify " + user_model.name + "[/cyan]")
            console.print()
            console.print("The verify command will guide you through the repair process.")
            raise typer.Exit(0)

        # Ask if user wants to continue anyway
        console.print()
        if not Confirm.ask(
            f"[yellow]Continue {operation_name} with potentially corrupted weights?[/yellow]", default=False
        ):
            raise typer.Exit(0)

        console.print()
        print_warning(f"Proceeding with {operation_name} using unverified weights at your own risk...")
        console.print()
|
||||
57
kt-kernel/python/cli/utils/port_checker.py
Normal file
57
kt-kernel/python/cli/utils/port_checker.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Port availability checking utilities.
|
||||
"""
|
||||
|
||||
import socket
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def is_port_available(host: str, port: int) -> bool:
    """Check if a port is available on the given host.

    Availability is probed with a TCP connection attempt: if something
    accepts the connection, the port is considered occupied.

    NOTE(review): a socket bound to the port but not listening, or a server
    bound only on a non-loopback interface when host is "0.0.0.0", may still
    be reported as available — confirm this is acceptable for callers.

    Args:
        host: Host address (e.g., "0.0.0.0", "127.0.0.1")
        port: Port number to check

    Returns:
        True if port is available, False if occupied
    """
    # Connecting to the wildcard address is not meaningful; probe loopback.
    target_host = "127.0.0.1" if host == "0.0.0.0" else host
    try:
        # Context manager guarantees the socket is closed on every path,
        # fixing a descriptor leak when an exception fired before close().
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # connect_ex == 0 means something accepted -> port is occupied;
            # any non-zero errno means nothing answered -> port is available.
            return sock.connect_ex((target_host, port)) != 0
    except Exception:
        # On any unexpected failure, conservatively report "not available".
        return False
|
||||
|
||||
|
||||
def find_available_port(host: str, start_port: int, max_attempts: int = 100) -> Tuple[bool, int]:
    """Scan ports upward from start_port and return the first available one.

    Args:
        host: Host address
        start_port: Starting port number to check
        max_attempts: Maximum number of ports to try

    Returns:
        Tuple of (found, port_number):
        - found: True if an available port was found
        - port_number: the available port (or start_port when none was found)
    """
    candidates = range(start_port, start_port + max_attempts)
    hit = next((candidate for candidate in candidates if is_port_available(host, candidate)), None)
    if hit is None:
        return False, start_port
    return True, hit
|
||||
347
kt-kernel/python/cli/utils/quant_interactive.py
Normal file
347
kt-kernel/python/cli/utils/quant_interactive.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
Interactive configuration for kt quant command.
|
||||
|
||||
Provides rich, multi-step interactive configuration for model quantization.
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from pathlib import Path
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt, Confirm, IntPrompt
|
||||
from kt_kernel.cli.i18n import t
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def select_model_to_quantize() -> Optional[Any]:
    """Interactively pick a MoE model eligible for quantization.

    Eligible models are safetensors-format MoE models whose weights are not
    already AMX-converted.

    Returns:
        The selected model object, or None when no eligible model exists or
        the user entered an out-of-range index.
    """
    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
    from kt_kernel.cli.commands.model import is_amx_weights, SHA256_STATUS_MAP
    from kt_kernel.cli.utils.model_table_builder import build_moe_gpu_table

    registry = UserModelRegistry()

    # Filter: safetensors format, not AMX weights, MoE architecture.
    candidates = []
    for entry in registry.list_models():
        if entry.format != "safetensors":
            continue
        amx_detected, _ = is_amx_weights(entry.path)
        if amx_detected or not entry.is_moe:
            continue
        candidates.append(entry)

    if not candidates:
        console.print(f"[yellow]{t('quant_no_moe_models')}[/yellow]")
        console.print()
        console.print(f" {t('quant_only_moe')}")
        console.print()
        console.print(f" {t('quant_add_models', command='kt model scan')}")
        console.print(f" {t('quant_add_models', command='kt model add <path>')}")
        return None

    console.print()
    console.print(f"[bold green]{t('quant_moe_available')}[/bold green]")
    console.print()

    # Shared table builder keeps this view consistent with `kt model list`.
    table, shown = build_moe_gpu_table(
        models=candidates, status_map=SHA256_STATUS_MAP, show_index=True, start_index=1
    )
    console.print(table)
    console.print()

    picked = IntPrompt.ask(t("quant_select_model"), default=1, show_choices=False)
    if not (1 <= picked <= len(shown)):
        console.print(f"[red]{t('quant_invalid_choice')}[/red]")
        return None

    return shown[picked - 1]
|
||||
|
||||
|
||||
def configure_quantization_method() -> Dict[str, str]:
    """Ask the user for a quantization method and the source weight dtype.

    Returns:
        Dict with keys "method" ("int4"/"int8") and "input_type"
        ("fp8"/"fp16"/"bf16").
    """
    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_step2_method')}[/bold cyan]", expand=False))
    console.print()

    # Quantization method
    console.print(f"[bold]{t('quant_method_label')}[/bold]")
    console.print(f" [cyan][1][/cyan] {t('quant_int4_desc')}")
    console.print(f" [cyan][2][/cyan] {t('quant_int8_desc')}")
    console.print()

    method_pick = Prompt.ask(t("quant_select_method"), choices=["1", "2"], default="1")
    method = {"1": "int4", "2": "int8"}[method_pick]

    # Source weight dtype
    console.print()
    console.print(f"[bold]{t('quant_input_type_label')}[/bold]")
    console.print(f" [cyan][1][/cyan] {t('quant_fp8_desc')}")
    console.print(f" [cyan][2][/cyan] {t('quant_fp16_desc')}")
    console.print(f" [cyan][3][/cyan] {t('quant_bf16_desc')}")
    console.print()

    dtype_pick = Prompt.ask(t("quant_select_input_type"), choices=["1", "2", "3"], default="1")
    input_type = {"1": "fp8", "2": "fp16", "3": "bf16"}[dtype_pick]

    return {"method": method, "input_type": input_type}
|
||||
|
||||
|
||||
def configure_cpu_params(max_cores: int, max_numa: int) -> Dict[str, Any]:
    """Interactively configure CPU thread count, NUMA nodes and GPU usage.

    Args:
        max_cores: Maximum logical CPU threads available.
        max_numa: Number of NUMA nodes on this machine.

    Returns:
        Dict with keys "cpu_threads" (int), "numa_nodes" (int) and
        "use_gpu" (bool).
    """
    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_step3_cpu')}[/bold cyan]", expand=False))
    console.print()

    def validated(value: int, min_val: int, max_val: int, default: int) -> int:
        """Return value when inside [min_val, max_val], otherwise default.

        Out-of-range input falls back to the suggested default rather than
        being clamped to the nearest bound. (The previous implementation
        also applied a redundant max/min to an already in-range value.)
        """
        return value if min_val <= value <= max_val else default

    # Default to 80% of the logical cores, leaving headroom for the system.
    default_threads = int(max_cores * 0.8)
    cpu_threads = IntPrompt.ask(t("quant_cpu_threads_prompt", max=max_cores), default=default_threads)
    cpu_threads = validated(cpu_threads, 1, max_cores, default_threads)

    numa_nodes = IntPrompt.ask(t("quant_numa_nodes_prompt", max=max_numa), default=max_numa)
    numa_nodes = validated(numa_nodes, 1, max_numa, max_numa)

    # Ask about GPU usage
    console.print()
    console.print(f"[bold]{t('quant_use_gpu_label')}[/bold]")
    console.print(f" [dim]{t('quant_gpu_speedup')}[/dim]")
    console.print()
    use_gpu = Confirm.ask(t("quant_enable_gpu"), default=True)

    return {"cpu_threads": cpu_threads, "numa_nodes": numa_nodes, "use_gpu": use_gpu}
|
||||
|
||||
|
||||
def configure_output_path(model: Any, method: str, numa_nodes: int) -> Path:
    """Determine where quantized weights should be written.

    Default directory priority: configured weights dir > first model storage
    path > the source model's parent directory. The user can accept the
    default or type a custom path.

    Args:
        model: Source model object (exposes .path).
        method: Quantization method ("int4"/"int8"), encoded in the dir name.
        numa_nodes: NUMA node count, encoded in the dir name.

    Returns:
        Path of the chosen output directory.
    """
    from kt_kernel.cli.config.settings import get_settings

    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_step4_output')}[/bold cyan]", expand=False))
    console.print()

    source_dir = Path(model.path)
    dir_name = f"{source_dir.name}-AMX{method.upper()}-NUMA{numa_nodes}"
    settings = get_settings()

    # Pick the base directory by priority: weights dir, first model storage
    # path, then fall back to the source model's parent directory.
    base_dir = None
    weights_dir = settings.weights_dir
    if weights_dir and weights_dir.exists():
        base_dir = weights_dir
    else:
        model_paths = settings.get_model_paths()
        if model_paths and model_paths[0].exists():
            base_dir = model_paths[0]
    if base_dir is None:
        base_dir = source_dir.parent

    default_output = base_dir / dir_name

    console.print(f"[dim]{t('quant_default_path')}[/dim]", default_output)
    console.print()

    if Confirm.ask(t("quant_use_default"), default=True):
        return default_output

    return Path(Prompt.ask(t("quant_custom_path"), default=str(default_output)))
|
||||
|
||||
|
||||
def calculate_quantized_size(source_path: Path, input_type: str, quant_method: str) -> tuple[float, float]:
    """
    Estimate the on-disk size of a model before and after quantization.

    Args:
        source_path: Path to source model
        input_type: Input type (fp8, fp16, bf16)
        quant_method: Quantization method (int4, int8)

    Returns:
        Tuple of (source_size_gb, estimated_quant_size_gb); (0.0, 0.0) when
        the source size cannot be determined.
    """
    gib = 1024**3

    # Sum the *.safetensors payload of the source model.
    try:
        total_bytes = sum(weight.stat().st_size for weight in source_path.glob("*.safetensors") if weight.is_file())
    except Exception:
        return 0.0, 0.0
    source_size_gb = total_bytes / gib

    # Bit widths per dtype; unknown values fall back to 16-bit in / 4-bit out.
    bits_in = {"fp8": 8, "fp16": 16, "bf16": 16}.get(input_type, 16)
    bits_out = {"int4": 4, "int8": 8}.get(quant_method, 4)

    # Quantized size scales linearly with the bit-width ratio.
    return source_size_gb, source_size_gb * (bits_out / bits_in)
|
||||
|
||||
|
||||
def check_disk_space(output_path: Path, required_size_gb: float) -> tuple[float, bool]:
    """
    Check free disk space for the output location.

    Climbs from output_path to its closest existing ancestor and queries
    the filesystem there.

    Args:
        output_path: Target output path
        required_size_gb: Required space in GB

    Returns:
        Tuple of (available_gb, is_sufficient); is_sufficient is True when
        free space >= required_size_gb * 1.2 (a 20% safety buffer).
        (0.0, False) when the free space cannot be determined.
    """
    import shutil

    try:
        # Start at the path itself when it exists, otherwise at its parent,
        # and climb until an existing directory (or the filesystem root).
        probe = output_path if output_path.exists() else output_path.parent
        while not probe.exists() and probe != probe.parent:
            probe = probe.parent

        free_gb = shutil.disk_usage(probe).free / (1024**3)
        return free_gb, free_gb >= required_size_gb * 1.2
    except Exception:
        return 0.0, False
|
||||
|
||||
|
||||
def interactive_quant_config() -> Optional[Dict[str, Any]]:
    """
    Interactive configuration flow for `kt quant`.

    Walks the user through model selection, quantization method, CPU/NUMA
    parameters and output path, then checks disk space and asks for final
    confirmation.

    Returns:
        Configuration dict with keys: model, method, input_type, cpu_threads,
        numa_nodes, use_gpu, output_path — or None if the user cancelled
        (no model chosen, insufficient space declined, or final prompt declined).
    """
    from kt_kernel.cli.utils.environment import detect_cpu_info

    # Get CPU info (thread count / NUMA topology used for defaults below)
    cpu_info = detect_cpu_info()

    # Step 1: Select model (None means the user aborted the selection)
    model = select_model_to_quantize()
    if not model:
        return None

    # Step 1.5: Pre-quantization verification (optional)
    # Imported lazily to keep CLI startup fast.
    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
    from kt_kernel.cli.utils.model_verifier import pre_operation_verification

    user_registry = UserModelRegistry()
    user_model_obj = user_registry.find_by_path(model.path)

    # Only safetensors models are verified before quantizing.
    if user_model_obj and user_model_obj.format == "safetensors":
        pre_operation_verification(user_model_obj, user_registry, operation_name="quantizing")

    # Step 2: Configure quantization method (e.g. int4/int8 + input dtype)
    quant_config = configure_quantization_method()

    # Step 3: Configure CPU parameters
    cpu_config = configure_cpu_params(cpu_info.threads, cpu_info.numa_nodes)  # Use logical threads

    # Step 4: Configure output path
    output_path = configure_output_path(model, quant_config["method"], cpu_config["numa_nodes"])

    # Step 4.5: Check if output path already exists and generate unique name
    if output_path.exists():
        console.print()
        console.print(t("quant_output_exists_warn", path=str(output_path)))
        console.print()

        # Generate unique name by adding suffix (-2, -3, ...) until free
        original_name = output_path.name
        parent_dir = output_path.parent
        counter = 2

        while output_path.exists():
            new_name = f"{original_name}-{counter}"
            output_path = parent_dir / new_name
            counter += 1

        console.print(t("quant_using_unique_name", path=str(output_path)))
        console.print()

    # Step 5: Calculate space requirements and check availability
    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_disk_analysis')}[/bold cyan]", expand=False))
    console.print()

    source_size_gb, estimated_size_gb = calculate_quantized_size(
        Path(model.path), quant_config["input_type"], quant_config["method"]
    )

    available_gb, is_sufficient = check_disk_space(output_path, estimated_size_gb)

    console.print(f" {t('quant_source_size'):<26} [cyan]{source_size_gb:.2f} GB[/cyan]")
    console.print(f" {t('quant_estimated_size'):<26} [yellow]{estimated_size_gb:.2f} GB[/yellow]")
    console.print(
        f" {t('quant_available_space'):<26} [{'green' if is_sufficient else 'red'}]{available_gb:.2f} GB[/{'green' if is_sufficient else 'red'}]"
    )
    console.print()

    # Warn (but allow override) when free space is below estimate + 20% buffer
    if not is_sufficient:
        required_with_buffer = estimated_size_gb * 1.2
        console.print(f"[bold red]⚠ {t('quant_insufficient_space')}[/bold red]")
        console.print()
        console.print(f" {t('quant_required_space'):<26} [yellow]{required_with_buffer:.2f} GB[/yellow]")
        console.print(f" {t('quant_available_space'):<26} [red]{available_gb:.2f} GB[/red]")
        console.print(f" {t('quant_shortage'):<26} [red]{required_with_buffer - available_gb:.2f} GB[/red]")
        console.print()
        console.print(f" {t('quant_may_fail')}")
        console.print()

        if not Confirm.ask(f"[yellow]{t('quant_continue_anyway')}[/yellow]", default=False):
            console.print(f"[yellow]{t('quant_cancelled')}[/yellow]")
            return None
        console.print()

    # Summary and confirmation
    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_config_summary')}[/bold cyan]", expand=False))
    console.print()
    console.print(f" {t('quant_summary_model'):<15} {model.name}")
    console.print(f" {t('quant_summary_method'):<15} {quant_config['method'].upper()}")
    console.print(f" {t('quant_summary_input_type'):<15} {quant_config['input_type'].upper()}")
    console.print(f" {t('quant_summary_cpu_threads'):<15} {cpu_config['cpu_threads']}")
    console.print(f" {t('quant_summary_numa'):<15} {cpu_config['numa_nodes']}")
    console.print(f" {t('quant_summary_gpu'):<15} {t('yes') if cpu_config['use_gpu'] else t('no')}")
    console.print(f" {t('quant_summary_output'):<15} {output_path}")
    console.print()

    if not Confirm.ask(f"[bold green]{t('quant_start_question')}[/bold green]", default=True):
        console.print(f"[yellow]{t('quant_cancelled')}[/yellow]")
        return None

    # Flattened config consumed by the quantization runner.
    return {
        "model": model,
        "method": quant_config["method"],
        "input_type": quant_config["input_type"],
        "cpu_threads": cpu_config["cpu_threads"],
        "numa_nodes": cpu_config["numa_nodes"],
        "use_gpu": cpu_config["use_gpu"],
        "output_path": output_path,
    }
|
||||
364
kt-kernel/python/cli/utils/repo_detector.py
Normal file
364
kt-kernel/python/cli/utils/repo_detector.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""
|
||||
Repo Detector
|
||||
|
||||
Automatically detect repository information from model README.md files
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Tuple
|
||||
import yaml
|
||||
|
||||
|
||||
def parse_readme_frontmatter(readme_path: Path) -> Optional[Dict]:
    """
    Parse YAML frontmatter from README.md.

    Args:
        readme_path: Path to README.md file

    Returns:
        Dictionary of frontmatter data, or None if the file is missing,
        unreadable, has no frontmatter block, or the YAML is invalid or
        not a mapping.
    """
    if not readme_path.exists():
        return None

    # Keep the try bodies minimal: only the operations that can raise.
    try:
        content = readme_path.read_text(encoding="utf-8")
    except Exception:
        return None

    # Frontmatter is the YAML document between leading "---" markers.
    match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)
    if not match:
        return None

    try:
        data = yaml.safe_load(match.group(1))
    except yaml.YAMLError:
        return None

    # Only a mapping is usable metadata; scalars/lists are rejected.
    return data if isinstance(data, dict) else None
|
||||
|
||||
|
||||
def extract_repo_from_frontmatter(frontmatter: Dict) -> Optional[Tuple[str, str]]:
    """
    Derive (repo_id, repo_type) from parsed README frontmatter.

    Lookup order: license_link URL, then base_model, then model-index,
    then model_name. repo_type defaults to "huggingface" unless ModelScope
    indicators appear in the repo id or the tags list.

    Args:
        frontmatter: Parsed YAML frontmatter dictionary

    Returns:
        Tuple of (repo_id, repo_type) or None
        repo_type is either "huggingface" or "modelscope"
    """
    if not frontmatter:
        return None

    # Most reliable source: a license link that points at the repo itself.
    link = frontmatter.get("license_link")
    if isinstance(link, str) and link:
        from_url = _extract_repo_from_url(link)
        if from_url:
            return from_url

    # Otherwise look through metadata fields that may carry a repo id.
    candidate = None

    base = frontmatter.get("base_model")
    if isinstance(base, list) and base:
        # A list of base models: take the first entry.
        candidate = base[0]
    elif isinstance(base, str):
        candidate = base

    if not candidate:
        index = frontmatter.get("model-index")
        if isinstance(index, list) and index and isinstance(index[0], dict):
            candidate = index[0].get("name")

    if not candidate:
        candidate = frontmatter.get("model_name")

    if not candidate or not isinstance(candidate, str):
        return None

    # A valid id looks like exactly "namespace/model-name".
    if candidate.count("/") != 1:
        return None

    # Default to HuggingFace unless ModelScope hints are present.
    repo_kind = "modelscope" if "modelscope" in candidate.lower() else "huggingface"

    tag_list = frontmatter.get("tags", [])
    if isinstance(tag_list, list) and "modelscope" in [str(tag).lower() for tag in tag_list]:
        repo_kind = "modelscope"

    return (candidate, repo_kind)
|
||||
|
||||
|
||||
def _extract_repo_from_url(url: str) -> Optional[Tuple[str, str]]:
|
||||
"""
|
||||
Extract repo_id and repo_type from a URL
|
||||
|
||||
Supports:
|
||||
- https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/LICENSE
|
||||
- https://modelscope.cn/models/Qwen/Qwen3-30B-A3B
|
||||
|
||||
Args:
|
||||
url: URL string
|
||||
|
||||
Returns:
|
||||
Tuple of (repo_id, repo_type) or None
|
||||
"""
|
||||
# HuggingFace pattern: https://huggingface.co/{namespace}/{model}/...
|
||||
hf_match = re.match(r"https?://huggingface\.co/([^/]+)/([^/]+)", url)
|
||||
if hf_match:
|
||||
namespace = hf_match.group(1)
|
||||
model_name = hf_match.group(2)
|
||||
repo_id = f"{namespace}/{model_name}"
|
||||
return (repo_id, "huggingface")
|
||||
|
||||
# ModelScope pattern: https://modelscope.cn/models/{namespace}/{model}
|
||||
ms_match = re.match(r"https?://(?:www\.)?modelscope\.cn/models/([^/]+)/([^/]+)", url)
|
||||
if ms_match:
|
||||
namespace = ms_match.group(1)
|
||||
model_name = ms_match.group(2)
|
||||
repo_id = f"{namespace}/{model_name}"
|
||||
return (repo_id, "modelscope")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_repo_from_global_search(readme_path: Path) -> Optional[Tuple[str, str]]:
    """
    Extract repo info by globally searching for repo URLs in README.md.

    When several distinct repos are referenced, the last mention wins.

    Args:
        readme_path: Path to README.md file

    Returns:
        Tuple of (repo_id, repo_type) or None if not found
    """
    if not readme_path.exists():
        return None

    try:
        content = readme_path.read_text(encoding="utf-8")

        # Find all HuggingFace URLs
        hf_pattern = r"https?://huggingface\.co/([^/\s]+)/([^/\s\)]+)"
        # Find all ModelScope URLs
        ms_pattern = r"https?://(?:www\.)?modelscope\.cn/models/([^/\s]+)/([^/\s\)]+)"

        found_repos = []

        for namespace, model_name in re.findall(hf_pattern, content):
            # Skip site sections / file-browser paths that match the URL shape
            if namespace.lower() in ("docs", "blog", "spaces", "datasets"):
                continue
            if model_name.lower() in ("tree", "blob", "raw", "resolve", "discussions"):
                continue
            found_repos.append((f"{namespace}/{model_name}", "huggingface"))

        for namespace, model_name in re.findall(ms_pattern, content):
            found_repos.append((f"{namespace}/{model_name}", "modelscope"))

        if not found_repos:
            return None

        # The original "dedupe then scan for last unique" dance always resolved
        # to the final match; say so directly.
        return found_repos[-1]

    except Exception:
        return None
|
||||
|
||||
|
||||
def detect_repo_for_model(model_path: str) -> Optional[Tuple[str, str]]:
    """
    Detect repository information for a model directory.

    Strategy: only the YAML frontmatter of README.md is consulted
    (global URL search was removed to avoid false positives).

    Args:
        model_path: Path to model directory

    Returns:
        Tuple of (repo_id, repo_type) or None if not detected
    """
    directory = Path(model_path)

    # Must be an existing directory containing a README.md.
    if not directory.exists() or not directory.is_dir():
        return None

    readme = directory / "README.md"
    if not readme.exists():
        return None

    metadata = parse_readme_frontmatter(readme)
    if not metadata:
        return None

    return extract_repo_from_frontmatter(metadata)
|
||||
|
||||
|
||||
def scan_models_for_repo(model_list) -> Dict:
    """
    Scan a list of models and detect repo information.

    Args:
        model_list: List of UserModel objects

    Returns:
        Dictionary with scan results:
        {
            'detected': [(model, repo_id, repo_type), ...],
            'not_detected': [model, ...],
            'skipped': [model, ...]  # Already has repo_id or unsupported format
        }
    """
    results = {"detected": [], "not_detected": [], "skipped": []}

    for entry in model_list:
        # Models that already carry a repo id need no detection, and only
        # safetensors/gguf directories are inspected.
        if entry.repo_id or entry.format not in ("safetensors", "gguf"):
            results["skipped"].append(entry)
            continue

        detection = detect_repo_for_model(entry.path)
        if detection:
            results["detected"].append((entry, detection[0], detection[1]))
        else:
            results["not_detected"].append(entry)

    return results
|
||||
|
||||
|
||||
def format_detection_report(results: Dict) -> str:
    """
    Format scan results into a readable report.

    Args:
        results: Results from scan_models_for_repo()

    Returns:
        Formatted string report with detected / not-detected / skipped
        sections and a one-line summary.
    """
    lines = []

    lines.append("=" * 80)
    lines.append("Auto-Detection Report")
    lines.append("=" * 80)
    lines.append("")

    # Models whose repository was successfully detected
    if results["detected"]:
        lines.append(f"✓ Detected repository information ({len(results['detected'])} models):")
        lines.append("")
        for model, repo_id, repo_type in results["detected"]:
            lines.append(f"  • {model.name}")
            lines.append(f"    Path: {model.path}")
            lines.append(f"    Repo: {repo_id} ({repo_type})")
            lines.append("")

    # Models with no detectable repository information
    if results["not_detected"]:
        lines.append(f"✗ No repository information found ({len(results['not_detected'])} models):")
        lines.append("")
        for model in results["not_detected"]:
            lines.append(f"  • {model.name}")
            lines.append(f"    Path: {model.path}")
            lines.append("")

    # Models the scan skipped entirely
    if results["skipped"]:
        lines.append(f"⊘ Skipped ({len(results['skipped'])} models):")
        # Plain string: the original used an f-string with no placeholders (F541).
        lines.append("  (Already have repo_id or not safetensors/gguf format)")
        lines.append("")

    lines.append("=" * 80)
    lines.append(
        f"Summary: {len(results['detected'])} detected, "
        f"{len(results['not_detected'])} not detected, "
        f"{len(results['skipped'])} skipped"
    )
    lines.append("=" * 80)

    return "\n".join(lines)
|
||||
|
||||
|
||||
def apply_detection_results(results: Dict, registry) -> int:
    """
    Apply detected repo information to models in the registry.

    Args:
        results: Results from scan_models_for_repo()
        registry: UserModelRegistry instance

    Returns:
        Number of models successfully updated
    """
    applied = 0

    for entry, repo_id, repo_type in results["detected"]:
        # registry.update_model returns a truthy value on success.
        if registry.update_model(entry.name, {"repo_id": repo_id, "repo_type": repo_type}):
            applied += 1

    return applied
|
||||
111
kt-kernel/python/cli/utils/run_configs.py
Normal file
111
kt-kernel/python/cli/utils/run_configs.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
Configuration save/load for kt run command.
|
||||
|
||||
Manages saved run configurations bound to specific models.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
import yaml
|
||||
|
||||
|
||||
CONFIG_FILE = Path.home() / ".ktransformers" / "run_configs.yaml"
|
||||
|
||||
|
||||
class RunConfigManager:
    """Persists and retrieves saved `kt run` configurations, keyed by model id."""

    def __init__(self):
        # All configurations live in one YAML file under ~/.ktransformers.
        self.config_file = CONFIG_FILE
        self._ensure_config_file()

    def _ensure_config_file(self):
        """Create the config file (and its parent directory) on first use."""
        if self.config_file.exists():
            return
        self.config_file.parent.mkdir(parents=True, exist_ok=True)
        self._save_data({"version": "1.0", "configs": {}})

    def _load_data(self) -> Dict:
        """Read the raw YAML document, falling back to an empty structure on any error."""
        try:
            with open(self.config_file, "r", encoding="utf-8") as handle:
                return yaml.safe_load(handle) or {"version": "1.0", "configs": {}}
        except Exception:
            return {"version": "1.0", "configs": {}}

    def _save_data(self, data: Dict):
        """Write the raw YAML document back to disk."""
        with open(self.config_file, "w", encoding="utf-8") as handle:
            yaml.dump(data, handle, allow_unicode=True, default_flow_style=False)

    def list_configs(self, model_id: str) -> List[Dict[str, Any]]:
        """Return all saved configs for *model_id*.

        Returns:
            List of config dicts with 'config_name' and other fields
            (empty list when none are stored or the entry is malformed).
        """
        stored = self._load_data().get("configs", {}).get(model_id, [])
        return stored if isinstance(stored, list) else []

    def save_config(self, model_id: str, config: Dict[str, Any]):
        """Append *config* (timestamped) to the model's saved configurations.

        Args:
            model_id: Model ID to bind config to
            config: Configuration dict with all run parameters
        """
        data = self._load_data()

        bucket = data.setdefault("configs", {}).setdefault(model_id, [])

        # Stamp creation time so configs can be ordered in the UI.
        config["created_at"] = datetime.now().isoformat()
        bucket.append(config)

        self._save_data(data)

    def delete_config(self, model_id: str, config_index: int) -> bool:
        """Delete a saved configuration.

        Args:
            model_id: Model ID
            config_index: Index of config to delete (0-based)

        Returns:
            True if deleted, False if the model or index was not found
        """
        data = self._load_data()

        if model_id not in data.get("configs", {}):
            return False

        bucket = data["configs"][model_id]
        if not (0 <= config_index < len(bucket)):
            return False

        bucket.pop(config_index)
        self._save_data(data)
        return True

    def get_config(self, model_id: str, config_index: int) -> Optional[Dict[str, Any]]:
        """Get a specific saved configuration.

        Args:
            model_id: Model ID
            config_index: Index of config to get (0-based)

        Returns:
            Config dict or None when the index is out of range
        """
        stored = self.list_configs(model_id)
        if not (0 <= config_index < len(stored)):
            return None
        return stored[config_index]
|
||||
1084
kt-kernel/python/cli/utils/run_interactive.py
Normal file
1084
kt-kernel/python/cli/utils/run_interactive.py
Normal file
File diff suppressed because it is too large
Load Diff
459
kt-kernel/python/cli/utils/tuna_engine.py
Normal file
459
kt-kernel/python/cli/utils/tuna_engine.py
Normal file
@@ -0,0 +1,459 @@
|
||||
"""
|
||||
Tuna engine for auto-tuning GPU experts configuration.
|
||||
|
||||
Automatically finds the maximum viable num-gpu-experts through binary search
|
||||
by testing actual server launches with different configurations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import random
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from kt_kernel.cli.utils.console import console, print_error, print_info, print_warning
|
||||
|
||||
|
||||
def get_num_experts(model_path: Path) -> int:
    """
    Get the total number of routed experts per layer from model config.

    Args:
        model_path: Path to the model directory

    Returns:
        Total number of experts per layer (used as the binary-search
        upper bound for --kt-num-gpu-experts)

    Raises:
        ValueError: If config.json not found, unparseable, or no expert
            count field is present
    """
    config_file = model_path / "config.json"

    if not config_file.exists():
        raise ValueError(f"config.json not found in {model_path}")

    try:
        config = json.loads(config_file.read_text())
    except Exception as e:
        raise ValueError(f"Failed to parse config.json: {e}")

    # Different models use different field names for the TOTAL expert count.
    # "num_experts_per_tok" is the number of ACTIVE experts per token (top-k),
    # not the total, so it must only be used as a last resort — checking it
    # first would return e.g. 8 instead of 256 on DeepSeek-style configs.
    possible_keys = [
        "n_routed_experts",  # DeepSeek
        "num_local_experts",  # Mixtral
        "num_experts",  # Qwen / generic
        "num_experts_per_tok",  # fallback only: active experts per token
    ]

    for key in possible_keys:
        if key in config:
            return config[key]

    raise ValueError(f"Cannot find num_experts field in {config_file}. " f"Tried: {', '.join(possible_keys)}")
|
||||
|
||||
|
||||
def detect_oom(log_line: Optional[str]) -> bool:
    """
    Detect OOM (Out Of Memory) errors from a line of server output.

    Args:
        log_line: A line from server output (None is treated as no error)

    Returns:
        True if an OOM indicator is present, False otherwise
    """
    if log_line is None:
        return False

    # Local import keeps this module's top-level dependencies unchanged
    # (matches the file's existing function-scope import style).
    import re

    log_lower = log_line.lower()

    # Phrases that unambiguously indicate an allocation failure.
    oom_phrases = [
        "cuda out of memory",
        "out of memory",
        "outofmemoryerror",
        "failed to allocate",
        "cumemalloc failed",
        "cumemallocasync failed",
        "allocation failed",
    ]
    if any(phrase in log_lower for phrase in oom_phrases):
        return True

    # "oom" must match as a standalone token: a bare substring check would
    # also fire on innocuous words such as "room", "zoom" or "bloom" and
    # abort a perfectly healthy server launch.
    return re.search(r"\boom\b", log_lower) is not None
|
||||
|
||||
|
||||
def test_config(
    num_gpu_experts: int,
    model_path: Path,
    config: dict,
    verbose: bool = False,
) -> tuple[bool, float]:
    """
    Test if a configuration with given num_gpu_experts works.

    Launches a real sglang server as a subprocess on a random local port,
    watches its output for OOM / startup markers, probes it with one
    inference request, then tears it down.

    Args:
        num_gpu_experts: Number of GPU experts to test
        model_path: Path to the model
        config: Configuration dict with all parameters (tensor_parallel_size
            and max_total_tokens are required; the rest have defaults)
        verbose: Whether to show detailed logs

    Returns:
        (success: bool, elapsed_time: float)
        - success: True if server starts and inference works
        - elapsed_time: Time taken for the test
    """
    start_time = time.time()

    # Use a random high port to avoid conflicts with a running server
    # or with concurrent probes.
    test_port = random.randint(30000, 40000)

    # Build the sglang launch command
    cmd = [
        sys.executable,
        "-m",
        "sglang.launch_server",
        "--model",
        str(model_path),
        "--port",
        str(test_port),
        "--host",
        "127.0.0.1",
        "--tensor-parallel-size",
        str(config["tensor_parallel_size"]),
        "--kt-num-gpu-experts",
        str(num_gpu_experts),
        "--max-total-tokens",
        str(config["max_total_tokens"]),
    ]

    # kt-kernel options: weight path falls back to the model path itself
    if config.get("weights_path"):
        cmd.extend(["--kt-weight-path", str(config["weights_path"])])
    else:
        cmd.extend(["--kt-weight-path", str(model_path)])

    cmd.extend(
        [
            "--kt-cpuinfer",
            str(config.get("cpu_threads", 64)),
            "--kt-threadpool-count",
            str(config.get("numa_nodes", 2)),
            "--kt-method",
            config.get("kt_method", "AMXINT4"),
            "--kt-gpu-prefill-token-threshold",
            str(config.get("kt_gpu_prefill_threshold", 4096)),
        ]
    )

    # Other SGLang options
    if config.get("attention_backend"):
        cmd.extend(["--attention-backend", config["attention_backend"]])

    cmd.extend(
        [
            "--trust-remote-code",
            "--mem-fraction-static",
            str(config.get("mem_fraction_static", 0.98)),
            "--chunked-prefill-size",
            str(config.get("chunked_prefill_size", 4096)),
            "--max-running-requests",
            str(config.get("max_running_requests", 1)),  # Use 1 for faster testing
            "--watchdog-timeout",
            str(config.get("watchdog_timeout", 3000)),
            "--enable-mixed-chunk",
            "--enable-p2p-check",
        ]
    )

    # Add disable-shared-experts-fusion if specified
    if config.get("disable_shared_experts_fusion"):
        cmd.append("--disable-shared-experts-fusion")

    # Caller-supplied extra CLI args are appended verbatim
    if config.get("extra_args"):
        cmd.extend(config["extra_args"])

    if verbose:
        console.print(f"[dim]Command: {' '.join(cmd)}[/dim]")

    # Start the server process; stderr is merged into stdout so a single
    # readline loop sees everything. bufsize=1 gives line buffering.
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            env=config.get("env"),
        )
    except Exception as e:
        if verbose:
            print_error(f"Failed to start process: {e}")
        return False, time.time() - start_time

    # Monitor process output until the server is ready, an OOM shows up,
    # the process dies, or the overall timeout elapses.
    timeout = 60  # Maximum 60 seconds to wait
    server_ready = False

    try:
        while time.time() - start_time < timeout:
            # Process exited before announcing readiness -> failure
            if process.poll() is not None:
                # Process exited
                if verbose:
                    print_warning("Process exited early")
                return False, time.time() - start_time

            # Read one output line. NOTE(review): readline() on a pipe is
            # actually blocking; the loop relies on the server producing
            # output steadily — the 60s cap is only checked between lines.
            try:
                line = process.stdout.readline()
                if not line:
                    time.sleep(0.1)
                    continue

                if verbose:
                    console.print(f"[dim]{line.rstrip()}[/dim]")

                # Fast OOM detection: abort immediately instead of waiting
                # for the full timeout.
                if detect_oom(line):
                    if verbose:
                        print_warning(f"OOM detected: {line.rstrip()}")
                    process.terminate()
                    try:
                        process.wait(timeout=2)
                    except subprocess.TimeoutExpired:
                        process.kill()
                    return False, time.time() - start_time

                # Startup success markers emitted by uvicorn/FastAPI
                if "Uvicorn running" in line or "Application startup complete" in line:
                    server_ready = True
                    break

            except Exception as e:
                if verbose:
                    print_warning(f"Error reading output: {e}")
                break

        if not server_ready:
            # Timeout or failed to start: terminate, escalate to kill
            process.terminate()
            try:
                process.wait(timeout=2)
            except subprocess.TimeoutExpired:
                process.kill()
            return False, time.time() - start_time

        # Server is ready; verify it can actually answer a request
        success = test_inference(test_port, verbose=verbose)

        # Cleanup: graceful terminate first, then kill if it hangs
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait(timeout=2)

        return success, time.time() - start_time

    except KeyboardInterrupt:
        # User cancelled: clean up the child, then propagate the interrupt
        process.terminate()
        try:
            process.wait(timeout=2)
        except subprocess.TimeoutExpired:
            process.kill()
        raise
    except Exception as e:
        if verbose:
            print_error(f"Test failed with exception: {e}")
        # Best-effort cleanup; never let teardown mask the original failure
        try:
            process.terminate()
            process.wait(timeout=2)
        except:
            try:
                process.kill()
            except:
                pass
        return False, time.time() - start_time
|
||||
|
||||
|
||||
def test_inference(port: int, verbose: bool = False) -> bool:
    """
    Probe a freshly started server with one minimal chat completion.

    Args:
        port: Server port
        verbose: Whether to show detailed logs

    Returns:
        True if the request returns a non-empty choice (or the OpenAI
        client is unavailable so no probe is possible), False otherwise
    """
    try:
        # Give the server a moment to finish initializing after the
        # startup log line appears.
        time.sleep(2)

        try:
            from openai import OpenAI
        except ImportError:
            # No client library: we cannot probe, so assume the server is fine.
            if verbose:
                print_warning("OpenAI package not available, skipping inference test")
            return True

        client = OpenAI(
            base_url=f"http://127.0.0.1:{port}/v1",
            api_key="test",
        )

        # Minimal deterministic request: a single token with a short timeout.
        response = client.chat.completions.create(
            model="test",
            messages=[{"role": "user", "content": "Hi"}],
            max_tokens=1,
            temperature=0,
            timeout=10,
        )

        # Valid iff at least one choice came back with actual content.
        success = response.choices and len(response.choices) > 0 and response.choices[0].message.content is not None

        if verbose:
            if success:
                print_info(f"Inference test passed: {response.choices[0].message.content}")
            else:
                print_warning("Inference test failed: no valid response")

        return success

    except Exception as e:
        if verbose:
            print_warning(f"Inference test failed: {e}")
        return False
|
||||
|
||||
|
||||
def find_max_gpu_experts(
    model_path: Path,
    config: dict,
    verbose: bool = False,
) -> int:
    """
    Binary search for the maximum viable num_gpu_experts in [0, num_experts].

    Each probe launches a real server (see test_config), so the search costs
    roughly log2(num_experts) startups.

    Args:
        model_path: Path to the model
        config: Configuration dict
        verbose: Whether to show detailed logs

    Returns:
        Maximum number of GPU experts that works (0 when nothing succeeds)
    """
    # Upper bound comes from the model's own config.json
    try:
        num_experts = get_num_experts(model_path)
    except ValueError as e:
        print_error(str(e))
        raise

    console.print()
    console.print(f"Binary search range: [0, {num_experts}]")
    console.print()

    lo, hi = 0, num_experts
    best = 0
    step = 0
    max_steps = math.ceil(math.log2(num_experts + 1))

    while lo <= hi:
        step += 1
        candidate = (lo + hi) // 2

        console.print(f"[{step}/{max_steps}] Testing gpu-experts={candidate}... ", end="")

        ok, took = test_config(candidate, model_path, config, verbose=verbose)

        if ok:
            # Works: remember it and try a larger count
            console.print(f"[green]✓ OK[/green] ({took:.1f}s)")
            best = candidate
            lo = candidate + 1
        else:
            # Fails: shrink the search window downward
            console.print(f"[red]✗ FAILED[/red] ({took:.1f}s)")
            hi = candidate - 1

    return best
|
||||
|
||||
|
||||
def run_tuna(
    model_path: Path,
    tensor_parallel_size: int,
    max_total_tokens: int,
    kt_method: str,
    verbose: bool = False,
    **kwargs,
) -> int:
    """
    Run tuna auto-tuning to find optimal num_gpu_experts.

    Drives find_max_gpu_experts() and, when even 0 GPU experts fails,
    prints diagnostics and raises instead of returning a useless result.

    Args:
        model_path: Path to the model
        tensor_parallel_size: Tensor parallel size
        max_total_tokens: Maximum total tokens
        kt_method: KT quantization method
        verbose: Whether to show detailed logs
        **kwargs: Additional configuration parameters forwarded to test_config
            (e.g. weights_path, cpu_threads, numa_nodes, attention_backend)

    Returns:
        Optimal num_gpu_experts value (0 means CPU-only experts)

    Raises:
        ValueError: If even gpu-experts=0 cannot start (minimum GPU memory
            requirements not met)
        KeyboardInterrupt: Re-raised when the user cancels mid-search
    """
    # Prepare configuration consumed by test_config; kwargs may override
    # nothing here but adds optional parameters.
    config = {
        "tensor_parallel_size": tensor_parallel_size,
        "max_total_tokens": max_total_tokens,
        "kt_method": kt_method,
        **kwargs,
    }

    # Run binary search; surface user cancellation cleanly
    try:
        result = find_max_gpu_experts(model_path, config, verbose=verbose)
    except KeyboardInterrupt:
        console.print()
        print_warning("Tuning cancelled by user")
        raise

    console.print()

    # result == 0 is ambiguous: either 0 works (CPU-only experts) or every
    # probe — including 0 — failed. Re-test 0 explicitly to distinguish.
    if result == 0:
        console.print("[yellow]Testing if gpu-experts=0 is viable...[/yellow]")
        success, _ = test_config(0, model_path, config, verbose=verbose)

        if not success:
            # Even 0 doesn't work: the base model layers alone exceed VRAM
            console.print()
            print_error("Failed to start server even with all experts on CPU (gpu-experts=0)")
            console.print()
            console.print("[bold]Possible reasons:[/bold]")
            console.print("  • Insufficient GPU memory for base model layers")
            console.print("  • max-total-tokens is too large for available VRAM")
            console.print("  • Tensor parallel configuration issue")
            console.print()
            console.print("[bold]Suggestions:[/bold]")
            console.print(f"  • Reduce --max-total-tokens (current: {max_total_tokens})")
            console.print(f"  • Reduce --tensor-parallel-size (current: {tensor_parallel_size})")
            console.print("  • Use more GPUs or GPUs with more VRAM")
            console.print("  • Try a smaller model")
            console.print()
            raise ValueError("Minimum GPU memory requirements not met")
        else:
            # 0 works but nothing more: usable, but warn about CPU-bound speed
            console.print()
            print_warning("All experts will run on CPU (gpu-experts=0). " "Performance will be limited by CPU speed.")

    return result
|
||||
302
kt-kernel/python/cli/utils/user_model_registry.py
Normal file
302
kt-kernel/python/cli/utils/user_model_registry.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""
|
||||
User Model Registry
|
||||
|
||||
Manages user-registered models in ~/.ktransformers/user_models.yaml
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any
|
||||
import yaml
|
||||
|
||||
|
||||
# Constants
|
||||
USER_MODELS_FILE = Path.home() / ".ktransformers" / "user_models.yaml"
|
||||
REGISTRY_VERSION = "1.0"
|
||||
|
||||
|
||||
@dataclass
class UserModel:
    """Represents a user-registered model.

    Instances are serialized to/from ~/.ktransformers/user_models.yaml
    via to_dict() / from_dict().
    """

    name: str  # User-editable name (default: folder name)
    path: str  # Absolute path to model directory
    format: str  # "safetensors" | "gguf"
    id: Optional[str] = None  # Unique UUID for this model (auto-generated if None)
    repo_type: Optional[str] = None  # "huggingface" | "modelscope" | None
    repo_id: Optional[str] = None  # e.g., "deepseek-ai/DeepSeek-V3"
    sha256_status: str = "not_checked"  # "not_checked" | "checking" | "passed" | "failed" | "no_repo"
    gpu_model_ids: Optional[List[str]] = None  # For llamafile/AMX: list of GPU model UUIDs to run with
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    last_verified: Optional[str] = None  # ISO format datetime
    # MoE information (cached from analyze_moe_model)
    is_moe: Optional[bool] = None  # True if MoE model, False if non-MoE, None if not analyzed
    moe_num_experts: Optional[int] = None  # Total number of experts (for MoE models)
    moe_num_experts_per_tok: Optional[int] = None  # Number of active experts per token (for MoE models)
    # AMX quantization metadata (for format == "amx")
    amx_source_model: Optional[str] = None  # Name of the source MoE model that was quantized
    amx_quant_method: Optional[str] = None  # "int4" | "int8"
    amx_numa_nodes: Optional[int] = None  # Number of NUMA nodes used for quantization

    def __post_init__(self):
        """Ensure a unique ID is set after initialization."""
        if self.id is None:
            import uuid

            self.id = str(uuid.uuid4())

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for YAML serialization."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "UserModel":
        """Create from a dictionary loaded from YAML.

        Unknown keys (e.g. fields written by a newer CLI version) are
        silently dropped instead of raising TypeError, so older versions
        can still read newer registry files.
        """
        from dataclasses import fields as dc_fields

        known = {f.name for f in dc_fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})

    def path_exists(self) -> bool:
        """Check if the model path still exists on disk."""
        return Path(self.path).exists()
|
||||
|
||||
|
||||
class UserModelRegistry:
|
||||
"""Manages the user model registry"""
|
||||
|
||||
def __init__(self, registry_file: Optional[Path] = None):
|
||||
"""
|
||||
Initialize the registry
|
||||
|
||||
Args:
|
||||
registry_file: Path to the registry YAML file (default: USER_MODELS_FILE)
|
||||
"""
|
||||
self.registry_file = registry_file or USER_MODELS_FILE
|
||||
self.models: List[UserModel] = []
|
||||
self.version = REGISTRY_VERSION
|
||||
|
||||
# Ensure directory exists
|
||||
self.registry_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load existing registry
|
||||
self.load()
|
||||
|
||||
def load(self) -> None:
|
||||
"""Load models from YAML file"""
|
||||
if not self.registry_file.exists():
|
||||
# Initialize empty registry
|
||||
self.models = []
|
||||
self.save() # Create the file
|
||||
return
|
||||
|
||||
try:
|
||||
with open(self.registry_file, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
if not data:
|
||||
self.models = []
|
||||
return
|
||||
|
||||
# Load version
|
||||
self.version = data.get("version", REGISTRY_VERSION)
|
||||
|
||||
# Load models
|
||||
models_data = data.get("models", [])
|
||||
self.models = [UserModel.from_dict(m) for m in models_data]
|
||||
|
||||
# Migrate: ensure all models have UUIDs (for backward compatibility)
|
||||
needs_save = False
|
||||
for model in self.models:
|
||||
if model.id is None:
|
||||
import uuid
|
||||
|
||||
model.id = str(uuid.uuid4())
|
||||
needs_save = True
|
||||
|
||||
if needs_save:
|
||||
self.save()
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load user model registry: {e}")
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save models to YAML file"""
|
||||
data = {"version": self.version, "models": [m.to_dict() for m in self.models]}
|
||||
|
||||
try:
|
||||
with open(self.registry_file, "w", encoding="utf-8") as f:
|
||||
yaml.safe_dump(data, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to save user model registry: {e}")
|
||||
|
||||
def add_model(self, model: UserModel) -> None:
|
||||
"""
|
||||
Add a model to the registry
|
||||
|
||||
Args:
|
||||
model: UserModel instance to add
|
||||
|
||||
Raises:
|
||||
ValueError: If a model with the same name already exists
|
||||
"""
|
||||
if self.check_name_conflict(model.name):
|
||||
raise ValueError(f"Model with name '{model.name}' already exists")
|
||||
|
||||
self.models.append(model)
|
||||
self.save()
|
||||
|
||||
def remove_model(self, name: str) -> bool:
|
||||
"""
|
||||
Remove a model from the registry
|
||||
|
||||
Args:
|
||||
name: Name of the model to remove
|
||||
|
||||
Returns:
|
||||
True if model was removed, False if not found
|
||||
"""
|
||||
original_count = len(self.models)
|
||||
self.models = [m for m in self.models if m.name != name]
|
||||
|
||||
if len(self.models) < original_count:
|
||||
self.save()
|
||||
return True
|
||||
return False
|
||||
|
||||
def update_model(self, name: str, updates: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Update a model's attributes
|
||||
|
||||
Args:
|
||||
name: Name of the model to update
|
||||
updates: Dictionary of attributes to update
|
||||
|
||||
Returns:
|
||||
True if model was updated, False if not found
|
||||
"""
|
||||
model = self.get_model(name)
|
||||
if not model:
|
||||
return False
|
||||
|
||||
# Update attributes
|
||||
for key, value in updates.items():
|
||||
if hasattr(model, key):
|
||||
setattr(model, key, value)
|
||||
|
||||
self.save()
|
||||
return True
|
||||
|
||||
def get_model(self, name: str) -> Optional[UserModel]:
|
||||
"""
|
||||
Get a model by name
|
||||
|
||||
Args:
|
||||
name: Name of the model
|
||||
|
||||
Returns:
|
||||
UserModel instance or None if not found
|
||||
"""
|
||||
for model in self.models:
|
||||
if model.name == name:
|
||||
return model
|
||||
return None
|
||||
|
||||
def get_model_by_id(self, model_id: str) -> Optional[UserModel]:
|
||||
"""
|
||||
Get a model by its unique ID
|
||||
|
||||
Args:
|
||||
model_id: UUID of the model
|
||||
|
||||
Returns:
|
||||
UserModel instance or None if not found
|
||||
"""
|
||||
for model in self.models:
|
||||
if model.id == model_id:
|
||||
return model
|
||||
return None
|
||||
|
||||
def list_models(self) -> List[UserModel]:
|
||||
"""
|
||||
List all models
|
||||
|
||||
Returns:
|
||||
List of all UserModel instances
|
||||
"""
|
||||
return self.models.copy()
|
||||
|
||||
def find_by_path(self, path: str) -> Optional[UserModel]:
|
||||
"""
|
||||
Find a model by its path
|
||||
|
||||
Args:
|
||||
path: Model directory path
|
||||
|
||||
Returns:
|
||||
UserModel instance or None if not found
|
||||
"""
|
||||
# Normalize paths for comparison
|
||||
search_path = str(Path(path).resolve())
|
||||
|
||||
for model in self.models:
|
||||
model_path = str(Path(model.path).resolve())
|
||||
if model_path == search_path:
|
||||
return model
|
||||
return None
|
||||
|
||||
def check_name_conflict(self, name: str, exclude_name: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Check if a name conflicts with existing models
|
||||
|
||||
Args:
|
||||
name: Name to check
|
||||
exclude_name: Optional name to exclude from check (for rename operations)
|
||||
|
||||
Returns:
|
||||
True if conflict exists, False otherwise
|
||||
"""
|
||||
for model in self.models:
|
||||
if model.name == name and model.name != exclude_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
def refresh_status(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Check all models and identify missing ones
|
||||
|
||||
Returns:
|
||||
Dictionary with 'valid' and 'missing' lists of model names
|
||||
"""
|
||||
valid = []
|
||||
missing = []
|
||||
|
||||
for model in self.models:
|
||||
if model.path_exists():
|
||||
valid.append(model.name)
|
||||
else:
|
||||
missing.append(model.name)
|
||||
|
||||
return {"valid": valid, "missing": missing}
|
||||
|
||||
def get_model_count(self) -> int:
|
||||
"""Get total number of registered models"""
|
||||
return len(self.models)
|
||||
|
||||
def suggest_name(self, base_name: str) -> str:
|
||||
"""
|
||||
Suggest a unique name based on base_name
|
||||
|
||||
Args:
|
||||
base_name: Base name to derive from
|
||||
|
||||
Returns:
|
||||
A unique name (may have suffix like -2, -3 etc.)
|
||||
"""
|
||||
if not self.check_name_conflict(base_name):
|
||||
return base_name
|
||||
|
||||
counter = 2
|
||||
while True:
|
||||
candidate = f"{base_name}-{counter}"
|
||||
if not self.check_name_conflict(candidate):
|
||||
return candidate
|
||||
counter += 1
|
||||
Reference in New Issue
Block a user