From 56cbd69ac40c2aa275058d7911959d8ced1b3ba8 Mon Sep 17 00:00:00 2001 From: Oql <1692110604@qq.com> Date: Wed, 4 Feb 2026 16:44:54 +0800 Subject: [PATCH] kt-cli enhancement (#1834) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [feat]: redesign kt run interactive configuration with i18n support - Redesign kt run with 8-step interactive flow (model selection, inference method, NUMA/CPU, GPU experts, KV cache, GPU/TP selection, parsers, host/port) - Add configuration save/load system (~/.ktransformers/run_configs.yaml) - Add i18n support for kt chat (en/zh translations) - Add universal input validators with auto-retry and Chinese comma support - Add port availability checker with auto-suggestion - Add parser configuration (--tool-call-parser, --reasoning-parser) - Remove tuna command and clean up redundant files - Fix: variable reference bug in run.py, filter to show only MoE models * [feat]: unify model selection UI and enable shared experts fusion by default - Unify kt run model selection table with kt model list display * Add Total size, MoE Size, Repo, and SHA256 status columns * Use consistent formatting and styling * Improve user decision-making with more information - Enable --disable-shared-experts-fusion by default * Change default value from False to True * Users can still override with --enable-shared-experts-fusion * [feat]: improve kt chat with performance metrics and better CJK support - Add performance metrics display after each response * Total time, TTFT (Time To First Token), TPOT (Time Per Output Token) * Accurate input/output token counts using model tokenizer * Fallback to estimation if tokenizer unavailable * Metrics shown in dim style (not prominent) - Fix Chinese character input issues * Replace Prompt.ask() with console.input() for better CJK support * Fixes backspace deletion showing half-characters - Suppress NumPy subnormal warnings * Filter "The value of the smallest subnormal" warnings * 
Cleaner CLI output on certain hardware environments * [fix]: correct TTFT measurement in kt chat - Move start_time initialization before API call - Previously start_time was set when receiving first chunk, causing TTFT ≈ 0ms - Now correctly measures time from request sent to first token received * [docs]: 添加 Clawdbot 集成指南 - KTransformers 企业级 AI 助手部署方案 * [docs]: 强调推荐使用 Kimi K2.5 作为核心模型,突出企业级推理能力 * [docs]: 添加 Clawdbot 飞书接入教程链接 * [feat]: improve CLI table display, model verification, and chat experience - Add sequence number (#) column to all model tables by default - Filter kt edit to show only MoE GPU models (exclude AMX) - Extend kt model verify to check *.json and *.py files in addition to weights - Fix re-verification bug where repaired files caused false failures - Suppress tokenizer debug output in kt chat token counting * [fix]: fix cpu cores. --------- Co-authored-by: skqliao --- kt-kernel/python/cli/commands/chat.py | 227 +- kt-kernel/python/cli/commands/model.py | 2859 +++++++++++++++-- kt-kernel/python/cli/commands/quant.py | 411 ++- kt-kernel/python/cli/commands/run.py | 606 ++-- .../python/cli/completions/kt-completion.bash | 6 +- kt-kernel/python/cli/i18n.py | 690 ++++ kt-kernel/python/cli/main.py | 273 +- .../python/cli/utils/analyze_moe_model.py | 413 +++ kt-kernel/python/cli/utils/debug_configs.py | 118 + kt-kernel/python/cli/utils/download_helper.py | 146 + .../python/cli/utils/input_validators.py | 216 ++ .../python/cli/utils/kv_cache_calculator.py | 207 ++ kt-kernel/python/cli/utils/model_discovery.py | 250 ++ kt-kernel/python/cli/utils/model_scanner.py | 790 +++++ .../python/cli/utils/model_table_builder.py | 254 ++ kt-kernel/python/cli/utils/model_verifier.py | 918 ++++++ kt-kernel/python/cli/utils/port_checker.py | 57 + .../python/cli/utils/quant_interactive.py | 347 ++ kt-kernel/python/cli/utils/repo_detector.py | 364 +++ kt-kernel/python/cli/utils/run_configs.py | 111 + kt-kernel/python/cli/utils/run_interactive.py | 1084 +++++++ 
kt-kernel/python/cli/utils/tuna_engine.py | 459 +++ .../python/cli/utils/user_model_registry.py | 302 ++ 23 files changed, 10327 insertions(+), 781 deletions(-) create mode 100644 kt-kernel/python/cli/utils/analyze_moe_model.py create mode 100644 kt-kernel/python/cli/utils/debug_configs.py create mode 100644 kt-kernel/python/cli/utils/download_helper.py create mode 100644 kt-kernel/python/cli/utils/input_validators.py create mode 100644 kt-kernel/python/cli/utils/kv_cache_calculator.py create mode 100644 kt-kernel/python/cli/utils/model_discovery.py create mode 100644 kt-kernel/python/cli/utils/model_scanner.py create mode 100644 kt-kernel/python/cli/utils/model_table_builder.py create mode 100644 kt-kernel/python/cli/utils/model_verifier.py create mode 100644 kt-kernel/python/cli/utils/port_checker.py create mode 100644 kt-kernel/python/cli/utils/quant_interactive.py create mode 100644 kt-kernel/python/cli/utils/repo_detector.py create mode 100644 kt-kernel/python/cli/utils/run_configs.py create mode 100644 kt-kernel/python/cli/utils/run_interactive.py create mode 100644 kt-kernel/python/cli/utils/tuna_engine.py create mode 100644 kt-kernel/python/cli/utils/user_model_registry.py diff --git a/kt-kernel/python/cli/commands/chat.py b/kt-kernel/python/cli/commands/chat.py index 938f0f5..a10ed2b 100644 --- a/kt-kernel/python/cli/commands/chat.py +++ b/kt-kernel/python/cli/commands/chat.py @@ -96,9 +96,9 @@ def chat( kt chat -t 0.9 --max-tokens 4096 # Adjust generation parameters """ if not HAS_OPENAI: - print_error("OpenAI Python SDK is required for chat functionality.") + print_error(t("chat_openai_required")) console.print() - console.print("Install it with:") + console.print(t("chat_install_hint")) console.print(" pip install openai") raise typer.Exit(1) @@ -114,10 +114,10 @@ def chat( console.print() console.print( Panel.fit( - f"[bold cyan]KTransformers Chat[/bold cyan]\n\n" - f"Server: [yellow]{final_host}:{final_port}[/yellow]\n" - f"Temperature: 
[cyan]{temperature}[/cyan] | Max tokens: [cyan]{max_tokens}[/cyan]\n\n" - f"[dim]Type '/help' for commands, '/quit' to exit[/dim]", + f"[bold cyan]{t('chat_title')}[/bold cyan]\n\n" + f"{t('chat_server')}: [yellow]{final_host}:{final_port}[/yellow]\n" + f"{t('chat_temperature')}: [cyan]{temperature}[/cyan] | {t('chat_max_tokens')}: [cyan]{max_tokens}[/cyan]\n\n" + f"[dim]{t('chat_help_hint')}[/dim]", border_style="cyan", ) ) @@ -152,31 +152,44 @@ def chat( ) # Test connection - print_info("Connecting to server...") + print_info(t("chat_connecting")) models = client.models.list() available_models = [m.id for m in models.data] if not available_models: - print_error("No models available on server") + print_error(t("chat_no_models")) raise typer.Exit(1) # Select model if model: if model not in available_models: - print_warning(f"Model '{model}' not found. Available models: {', '.join(available_models)}") + print_warning(t("chat_model_not_found", model=model, available=", ".join(available_models))) selected_model = available_models[0] else: selected_model = model else: selected_model = available_models[0] - print_success(f"Connected to model: {selected_model}") + print_success(t("chat_connected", model=selected_model)) console.print() + # Load tokenizer for accurate token counting + tokenizer = None + try: + from transformers import AutoTokenizer + + # selected_model is the model path + tokenizer = AutoTokenizer.from_pretrained(selected_model, trust_remote_code=True) + console.print(f"[dim]Loaded tokenizer from {selected_model}[/dim]") + console.print() + except Exception as e: + console.print(f"[dim yellow]Warning: Could not load tokenizer, token counts will be estimated[/dim]") + console.print() + except Exception as e: - print_error(f"Failed to connect to server: {e}") + print_error(t("chat_connect_failed", error=str(e))) console.print() - console.print("Make sure the model server is running:") + console.print(t("chat_server_not_running")) console.print(" kt run ") 
raise typer.Exit(1) @@ -201,12 +214,12 @@ def chat( # Main chat loop try: while True: - # Get user input + # Get user input - use console.input() for better CJK character support try: - user_input = Prompt.ask("[bold green]You[/bold green]") + user_input = console.input(f"[bold green]{t('chat_user_prompt')}[/bold green]: ") except (EOFError, KeyboardInterrupt): console.print() - print_info("Goodbye!") + print_info(t("chat_goodbye")) break if not user_input.strip(): @@ -224,15 +237,19 @@ def chat( # Generate response console.print() - console.print("[bold cyan]Assistant[/bold cyan]") + console.print(f"[bold cyan]{t('chat_assistant_prompt')}[/bold cyan]") try: if stream: # Streaming response - response_content = _stream_response(client, selected_model, messages, temperature, max_tokens) + response_content = _stream_response( + client, selected_model, messages, temperature, max_tokens, tokenizer + ) else: # Non-streaming response - response_content = _generate_response(client, selected_model, messages, temperature, max_tokens) + response_content = _generate_response( + client, selected_model, messages, temperature, max_tokens, tokenizer + ) # Add assistant response to history messages.append({"role": "assistant", "content": response_content}) @@ -240,7 +257,7 @@ def chat( console.print() except Exception as e: - print_error(f"Error generating response: {e}") + print_error(t("chat_generation_error", error=str(e))) # Remove the user message that caused the error messages.pop() continue @@ -252,12 +269,12 @@ def chat( except KeyboardInterrupt: console.print() console.print() - print_info("Chat interrupted. 
Goodbye!") + print_info(t("chat_interrupted")) # Final history save if save_history and messages: _save_history(history_file, messages, selected_model) - console.print(f"[dim]History saved to: {history_file}[/dim]") + console.print(f"[dim]{t('chat_history_saved', path=str(history_file))}[/dim]") console.print() @@ -267,12 +284,22 @@ def _stream_response( messages: list, temperature: float, max_tokens: int, + tokenizer=None, ) -> str: """Generate streaming response and display in real-time.""" + import time + response_content = "" reasoning_content = "" + # Performance tracking + first_token_time = None + chunk_count = 0 + try: + # Start timing before sending request + start_time = time.time() + stream = client.chat.completions.create( model=model, messages=messages, @@ -282,33 +309,120 @@ def _stream_response( ) for chunk in stream: - delta = chunk.choices[0].delta - reasoning_delta = getattr(delta, "reasoning_content", None) - if reasoning_delta: - reasoning_content += reasoning_delta - console.print(reasoning_delta, end="", style="dim") - if delta.content: - content = delta.content - response_content += content - console.print(content, end="") + delta = chunk.choices[0].delta if chunk.choices else None + if delta: + reasoning_delta = getattr(delta, "reasoning_content", None) + if reasoning_delta: + if first_token_time is None: + first_token_time = time.time() + reasoning_content += reasoning_delta + console.print(reasoning_delta, end="", style="dim") + chunk_count += 1 + + if delta.content: + if first_token_time is None: + first_token_time = time.time() + content = delta.content + response_content += content + console.print(content, end="") + chunk_count += 1 console.print() # Newline after streaming + # Display performance metrics + end_time = time.time() + if first_token_time and chunk_count > 0: + ttft = first_token_time - start_time + total_time = end_time - start_time + + # Calculate TPOT based on chunks + if chunk_count > 1: + generation_time = total_time - 
ttft + tpot = generation_time / (chunk_count - 1) + else: + tpot = 0 + + # Calculate accurate token counts using tokenizer + if tokenizer: + input_tokens = _count_tokens_with_tokenizer(messages, tokenizer) + output_tokens = _count_tokens_with_tokenizer( + [{"role": "assistant", "content": response_content}], tokenizer + ) + token_prefix = "" + else: + # Fallback to estimation + input_tokens = _estimate_tokens(messages) + output_tokens = _estimate_tokens([{"role": "assistant", "content": response_content}]) + token_prefix = "~" + + # Build metrics display + metrics = f"[dim]Total: {total_time*1000:.0f}ms | TTFT: {ttft*1000:.0f}ms" + if tpot > 0: + metrics += f" | TPOT: {tpot*1000:.1f}ms" + metrics += f" | In: {token_prefix}{input_tokens} | Out: {token_prefix}{output_tokens}" + metrics += "[/dim]" + + console.print(metrics) + except Exception as e: raise Exception(f"Streaming error: {e}") return response_content +def _count_tokens_with_tokenizer(messages: list, tokenizer) -> int: + """Count tokens accurately using the model's tokenizer.""" + try: + # Concatenate all message content + text = "" + for msg in messages: + role = msg.get("role", "") + content = msg.get("content", "") + # Simple format: role + content + text += f"{role}: {content}\n" + + # Encode and count tokens - suppress any debug output from custom tokenizers + import os + import sys + from contextlib import redirect_stdout, redirect_stderr + + with open(os.devnull, "w") as devnull: + with redirect_stdout(devnull), redirect_stderr(devnull): + tokens = tokenizer.encode(text, add_special_tokens=True) + return len(tokens) + except Exception: + # Fallback to estimation if tokenizer fails + return _estimate_tokens(messages) + + +def _estimate_tokens(messages: list) -> int: + """Estimate token count for messages (rough approximation).""" + total_chars = 0 + for msg in messages: + content = msg.get("content", "") + total_chars += len(content) + + # Rough estimation: + # - English: ~4 chars per token + # - 
Chinese: ~1.5 chars per token + # Use 2.5 as average + return max(1, int(total_chars / 2.5)) + + def _generate_response( client: "OpenAI", model: str, messages: list, temperature: float, max_tokens: int, + tokenizer=None, ) -> str: """Generate non-streaming response.""" + import time + try: + start_time = time.time() + response = client.chat.completions.create( model=model, messages=messages, @@ -317,12 +431,36 @@ def _generate_response( stream=False, ) + end_time = time.time() + total_time = end_time - start_time + content = response.choices[0].message.content # Display as markdown md = Markdown(content) console.print(md) + # Calculate accurate token counts using tokenizer + if tokenizer: + input_tokens = _count_tokens_with_tokenizer(messages, tokenizer) + output_tokens = _count_tokens_with_tokenizer([{"role": "assistant", "content": content}], tokenizer) + token_prefix = "" + else: + # Fallback to API usage or estimation + input_tokens = response.usage.prompt_tokens if response.usage else _estimate_tokens(messages) + output_tokens = ( + response.usage.completion_tokens + if response.usage + else _estimate_tokens([{"role": "assistant", "content": content}]) + ) + token_prefix = "" if response.usage else "~" + + # Display performance metrics + console.print( + f"[dim]Time: {total_time*1000:.0f}ms | " + f"In: {token_prefix}{input_tokens} | Out: {token_prefix}{output_tokens}[/dim]" + ) + return content except Exception as e: @@ -335,20 +473,14 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens if cmd in ["/quit", "/exit", "/q"]: console.print() - print_info("Goodbye!") + print_info(t("chat_goodbye")) return False elif cmd in ["/help", "/h"]: console.print() console.print( Panel( - "[bold]Available Commands:[/bold]\n\n" - "/help, /h - Show this help message\n" - "/quit, /exit, /q - Exit chat\n" - "/clear, /c - Clear conversation history\n" - "/history, /hist - Show conversation history\n" - "/info, /i - Show current settings\n" - 
"/retry, /r - Regenerate last response", + f"[bold]{t('chat_help_title')}[/bold]\n\n{t('chat_help_content')}", title="Help", border_style="cyan", ) @@ -359,19 +491,19 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens elif cmd in ["/clear", "/c"]: messages.clear() console.print() - print_success("Conversation history cleared") + print_success(t("chat_history_cleared")) console.print() return True elif cmd in ["/history", "/hist"]: console.print() if not messages: - print_info("No conversation history") + print_info(t("chat_no_history")) else: console.print( Panel( _format_history(messages), - title=f"History ({len(messages)} messages)", + title=t("chat_history_title", count=len(messages)), border_style="cyan", ) ) @@ -382,10 +514,7 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens console.print() console.print( Panel( - f"[bold]Current Settings:[/bold]\n\n" - f"Temperature: [cyan]{temperature}[/cyan]\n" - f"Max tokens: [cyan]{max_tokens}[/cyan]\n" - f"Messages: [cyan]{len(messages)}[/cyan]", + f"[bold]{t('chat_info_title')}[/bold]\n\n{t('chat_info_content', temperature=temperature, max_tokens=max_tokens, messages=len(messages))}", title="Info", border_style="cyan", ) @@ -397,16 +526,16 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens if len(messages) >= 2 and messages[-1]["role"] == "assistant": # Remove last assistant response messages.pop() - print_info("Retrying last response...") + print_info(t("chat_retrying")) console.print() else: - print_warning("No previous response to retry") + print_warning(t("chat_no_retry")) console.print() return True else: - print_warning(f"Unknown command: {command}") - console.print("[dim]Type /help for available commands[/dim]") + print_warning(t("chat_unknown_command", command=command)) + console.print(f"[dim]{t('chat_unknown_hint')}[/dim]") console.print() return True diff --git a/kt-kernel/python/cli/commands/model.py 
b/kt-kernel/python/cli/commands/model.py index 772ef8b..1476ae7 100644 --- a/kt-kernel/python/cli/commands/model.py +++ b/kt-kernel/python/cli/commands/model.py @@ -6,22 +6,82 @@ Manages models: download, list, and storage paths. import os from pathlib import Path -from typing import Optional +from typing import Optional, List import typer from kt_kernel.cli.config.settings import get_settings -from kt_kernel.cli.i18n import t +from kt_kernel.cli.i18n import t, get_lang from kt_kernel.cli.utils.console import ( confirm, console, print_error, print_info, + print_step, print_success, print_warning, prompt_choice, ) + +# Common SHA256 status display mapping used across multiple commands +SHA256_STATUS_MAP = { + "not_checked": "[dim]Not Checked[/dim]", + "checking": "[yellow]Checking...[/yellow]", + "passed": "[green]✓ Passed[/green]", + "failed": "[red]✗ Failed[/red]", + "no_repo": "[dim]-[/dim]", +} + +# Plain text version for panels and verbose output +SHA256_STATUS_MAP_PLAIN = { + "not_checked": "Not Checked", + "checking": "Checking...", + "passed": "✓ Passed", + "failed": "✗ Failed", + "no_repo": "-", +} + + +def is_amx_weights(model_path) -> tuple[bool, int]: + """ + Determine if a model uses AMX weights and count NUMA nodes. + + Returns: + (is_amx, numa_count): Tuple where is_amx indicates AMX weights, + and numa_count is the number of NUMA nodes (0 if not AMX). + """ + import re + from pathlib import Path + from safetensors import safe_open + + model_path = Path(model_path) + safetensors_files = sorted(model_path.glob("*.safetensors")) + + if not safetensors_files: + return False, 0 + + numa_indices = set() + numa_pattern = re.compile(r"\.numa\.(\d+)\.") + + # Check first 3 files for NUMA keys + for file_path in safetensors_files[:3]: + try: + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in f.keys(): + if ".numa." 
in key: + match = numa_pattern.search(key) + if match: + numa_indices.add(int(match.group(1))) + except Exception: + continue + + if not numa_indices: + return False, 0 + + return True, len(numa_indices) + + app = typer.Typer( help="Manage models and storage paths", invoke_without_command=True, @@ -36,76 +96,25 @@ def callback(ctx: typer.Context) -> None: Run without arguments to see available models. """ - # If no subcommand is provided, show the model list + # If no subcommand is provided, show the full model list if ctx.invoked_subcommand is None: - show_model_list() - - -def show_model_list() -> None: - """Display available models with their status and paths.""" - from rich.table import Table - from kt_kernel.cli.utils.model_registry import get_registry - from kt_kernel.cli.i18n import get_lang - - registry = get_registry() - settings = get_settings() - - console.print() - console.print(f"[bold cyan]{t('model_supported_title')}[/bold cyan]\n") - - # Get local models mapping - local_models = {m.name: p for m, p in registry.find_local_models()} - - # Create table - table = Table(show_header=True, header_style="bold") - table.add_column(t("model_column_model"), style="cyan", no_wrap=True) - table.add_column(t("model_column_status"), justify="center") - - all_models = registry.list_all() - for model in all_models: - if model.name in local_models: - status = f"[green]✓ {t('model_status_local')}[/green]" - else: - status = "[dim]-[/dim]" - - table.add_row(model.name, status) - - console.print(table) - console.print() - - # Usage instructions - console.print(f"[bold]{t('model_usage_title')}:[/bold]") - console.print(f" • {t('model_usage_download')} [cyan]kt model download [/cyan]") - console.print(f" • {t('model_usage_list_local')} [cyan]kt model list --local[/cyan]") - console.print(f" • {t('model_usage_search')} [cyan]kt model search [/cyan]") - console.print() - - # Show model storage paths - model_paths = settings.get_model_paths() - 
console.print(f"[bold]{t('model_storage_paths_title')}:[/bold]") - for path in model_paths: - marker = "[green]✓[/green]" if path.exists() else "[dim]✗[/dim]" - console.print(f" {marker} {path}") - console.print() + list_models(verbose=False, all_models=False, show_moe=True, no_cache=False) @app.command(name="download") def download( - model: Optional[str] = typer.Argument( + repo: Optional[str] = typer.Argument(None, help="Repository ID (optional, interactive mode if not provided)"), + local_dir: Optional[str] = typer.Option( None, - help="Model name or HuggingFace repo (e.g., deepseek-v3, Qwen/Qwen3-30B)", + "--local-dir", + "-d", + help="Local directory to download to (default: auto-detect from config)", ), - path: Optional[Path] = typer.Option( + repo_type: Optional[str] = typer.Option( None, - "--path", - "-p", - help="Custom download path", - ), - list_models: bool = typer.Option( - False, - "--list", - "-l", - help="List available models", + "--repo-type", + "-t", + help="Repository type: huggingface or modelscope", ), resume: bool = typer.Option( True, @@ -116,202 +125,934 @@ def download( False, "--yes", "-y", - help="Skip confirmation prompts", + help="Skip all prompts and use defaults", ), ) -> None: - """Download model weights from HuggingFace.""" + """Download model from HuggingFace or ModelScope (interactive mode).""" import subprocess - from kt_kernel.cli.i18n import get_lang - from kt_kernel.cli.utils.console import print_model_table, print_step - from kt_kernel.cli.utils.model_registry import get_registry + import os + from pathlib import Path + from rich.prompt import Prompt, Confirm + from rich.table import Table + from kt_kernel.cli.utils.user_model_registry import UserModelRegistry, UserModel + from kt_kernel.cli.utils.model_scanner import scan_single_path, format_size + from kt_kernel.cli.utils.model_verifier import check_huggingface_connectivity + from kt_kernel.cli.utils.download_helper import ( + list_remote_files_hf, + 
list_remote_files_ms, + filter_files_by_pattern, + calculate_total_size, + format_file_list_table, + verify_repo_exists, + ) settings = get_settings() - registry = get_registry() + user_registry = UserModelRegistry() console.print() - # List mode - if list_models or model is None: - print_step(t("download_list_title")) + # ========== Step 1: Select repository type ========== + if not repo_type and not yes: + console.print("[bold cyan]Step 1: Select Repository Source[/bold cyan]\n") + console.print(" 1. HuggingFace") + console.print(" 2. ModelScope") console.print() - models = registry.list_all() - model_dicts = [] - for m in models: - lang = get_lang() - desc = m.description_zh if lang == "zh" and m.description_zh else m.description - model_dicts.append( - { - "name": m.name, - "hf_repo": m.hf_repo, - "type": m.type, - "gpu_vram_gb": m.gpu_vram_gb, - "cpu_ram_gb": m.cpu_ram_gb, - } - ) + choice = Prompt.ask("Select source", choices=["1", "2"], default="1") + repo_type = "huggingface" if choice == "1" else "modelscope" + console.print() + elif not repo_type: + repo_type = "huggingface" # Default for --yes mode - print_model_table(model_dicts) + # Validate repo_type + if repo_type not in ["huggingface", "modelscope"]: + print_error(f"Invalid repo type: {repo_type}. 
Must be 'huggingface' or 'modelscope'") + raise typer.Exit(1) + + # Check HuggingFace connectivity and auto-switch to mirror if needed + use_mirror = False + if repo_type == "huggingface": + with console.status("[dim]Checking HuggingFace connectivity...[/dim]"): + is_accessible, message = check_huggingface_connectivity(timeout=5) + + if not is_accessible: + print_warning("HuggingFace Connection Failed") + console.print() + console.print(f" {message}") + console.print() + console.print(" [yellow]Auto-switching to HuggingFace mirror:[/yellow] [cyan]hf-mirror.com[/cyan]") + console.print() + use_mirror = True + + # ========== Step 2: Input repository ID ========== + while True: + if not repo and not yes: + console.print("[bold cyan]Step 2: Enter Repository ID[/bold cyan]\n") + console.print(" Examples:") + console.print(" • HuggingFace: deepseek-ai/DeepSeek-V3") + console.print(" • ModelScope: Qwen/Qwen3-Coder-480B-A35B-Instruct") + console.print() + + repo = Prompt.ask("Repository ID") + console.print() + elif not repo: + print_error("Repository ID is required") + raise typer.Exit(1) + + # Verify repository exists + with console.status(f"[dim]Verifying repository: {repo}...[/dim]"): + exists, msg = verify_repo_exists(repo, repo_type, use_mirror) + + if exists: + print_success(f"✓ Repository found: {repo}") + console.print() + break + else: + print_error(msg) + console.print() + if yes: + raise typer.Exit(1) + repo = None # Reset to ask again + + # ========== Step 3: Input file pattern and preview files ========== + files_to_download = [] + file_pattern = "*" + + while True: + if not yes: + console.print("[bold cyan]Step 3: Select Files to Download[/bold cyan]\n") + console.print(" File pattern (glob syntax):") + console.print(" • * - All files (default)") + console.print(" • *.safetensors - Only safetensors files") + console.print(" • *.gguf - Only GGUF files") + console.print(" • *Q4_K_M.gguf - Specific GGUF quant") + console.print() + + pattern_input = 
Prompt.ask("File pattern", default="*") + file_pattern = pattern_input + console.print() + + # Fetch remote file list + with console.status(f"[dim]Fetching file list from {repo_type}...[/dim]"): + try: + if repo_type == "huggingface": + all_files = list_remote_files_hf(repo, use_mirror) + else: + all_files = list_remote_files_ms(repo) + + files_to_download = filter_files_by_pattern(all_files, file_pattern) + except Exception as e: + print_error(f"Failed to fetch file list: {e}") + raise typer.Exit(1) + + if not files_to_download: + print_warning(f"No files match pattern: {file_pattern}") + console.print() + if yes: + raise typer.Exit(1) + continue # Ask for pattern again + + # Display matched files + total_size = calculate_total_size(files_to_download) + print_success(f"Found {len(files_to_download)} files (Total: {format_size(total_size)})") console.print() - if model is None: - console.print(f"[dim]{t('model_download_usage_hint')}[/dim]") + file_table = format_file_list_table(files_to_download, max_display=10) + console.print(file_table) + console.print() + + # Confirm or retry + if yes: + break + + action = Prompt.ask("Action", choices=["continue", "retry", "cancel"], default="continue") + + if action == "continue": + console.print() + break + elif action == "cancel": + console.print() + print_info("Download cancelled") + console.print() + return + # else retry - loop continues + + # ========== Step 4: Select download path ========== + download_path = None + + if local_dir: + download_path = Path(os.path.expanduser(local_dir)).resolve() + elif not yes: + console.print("[bold cyan]Step 4: Select Download Location[/bold cyan]\n") + + # Get configured model paths + model_paths = settings.get_model_paths() + if not model_paths: + print_error("No model storage paths configured.") + console.print() + console.print(f" Add a path with: [cyan]kt model path-add [/cyan]") + console.print() + raise typer.Exit(1) + + # Display configured paths + console.print(" Configured 
storage paths:") + for i, path in enumerate(model_paths, 1): + console.print(f" {i}. {path}") + console.print(f" {len(model_paths) + 1}. Custom path (manual input)") + console.print() + + path_choice = Prompt.ask("Select path", choices=[str(i) for i in range(1, len(model_paths) + 2)], default="1") + + if int(path_choice) <= len(model_paths): + base_path = model_paths[int(path_choice) - 1] + else: + custom = Prompt.ask("Enter custom path") + base_path = Path(os.path.expanduser(custom)).resolve() + + console.print() + + # Ask for folder name + default_folder = repo.split("/")[-1] + folder_name = Prompt.ask("Folder name", default=default_folder) + + download_path = base_path / folder_name + console.print() + else: + # --yes mode: use default + model_paths = settings.get_model_paths() + if not model_paths: + print_error("No model storage paths configured.") + raise typer.Exit(1) + + default_folder = repo.split("/")[-1] + download_path = model_paths[0] / default_folder + + # ========== Step 5: Confirm and download ========== + print_info(f"Download destination: {download_path}") + console.print() + + # Check if path exists + if download_path.exists(): + existing = user_registry.find_by_path(str(download_path)) + if existing: + print_warning(f"Model already registered as: {existing.name}") + console.print() + if not yes and not Confirm.ask("Re-download anyway?", default=False): + return + else: + print_warning(f"Directory already exists: {download_path}") + if not yes and not Confirm.ask("Overwrite?", default=False): + return + console.print() + + # Final confirmation + if not yes: + console.print("[bold]Download Summary:[/bold]") + console.print(f" Source: {repo_type}:{repo}") + console.print( + f" Files: {len(files_to_download)} files ({format_size(calculate_total_size(files_to_download))})" + ) + console.print(f" Pattern: {file_pattern}") + console.print(f" Destination: {download_path}") + console.print() + + if not Confirm.ask("Start download?", default=True): + 
console.print() + print_info("Download cancelled") console.print() return - # Search for model - print_step(t("download_searching", name=model)) - - # Check if it's a direct HuggingFace repo path - if "/" in model: - hf_repo = model - model_info = None - model_name = model.split("/")[-1] - else: - matches = registry.search(model) - - if not matches: - print_error(t("run_model_not_found", name=model)) - console.print() - console.print(t("model_download_list_hint")) - console.print(t("model_download_hf_hint")) - raise typer.Exit(1) - - if len(matches) == 1: - model_info = matches[0] - else: - console.print() - print_info(t("download_multiple_found")) - choices = [f"{m.name} ({m.hf_repo})" for m in matches] - selected = prompt_choice(t("download_select"), choices) - idx = choices.index(selected) - model_info = matches[idx] - - hf_repo = model_info.hf_repo - model_name = model_info.name - - print_success(t("download_found", name=hf_repo)) - - # Determine download path - if path is None: - download_path = settings.models_dir / model_name.replace(" ", "-") - else: - download_path = path - + # Download console.print() - print_info(t("download_destination", path=str(download_path))) - - # Check if already exists - if download_path.exists() and (download_path / "config.json").exists(): - print_warning(t("download_already_exists", path=str(download_path))) - if not yes: - if not confirm(t("download_overwrite_prompt"), default=False): - raise typer.Abort() - - # Confirm download - if not yes: - console.print() - if not confirm(t("prompt_continue")): - raise typer.Abort() - - # Download using huggingface-cli + print_step("Downloading model files...") console.print() - print_step(t("download_starting")) - cmd = [ - "huggingface-cli", - "download", - hf_repo, - "--local-dir", - str(download_path), - ] - - if resume: - cmd.append("--resume-download") - - # Add mirror if configured - mirror = settings.get("download.mirror", "") - if mirror: - cmd.extend(["--endpoint", mirror]) + # 
Set mirror for HuggingFace if needed + original_hf_endpoint = os.environ.get("HF_ENDPOINT") + if use_mirror and repo_type == "huggingface" and not original_hf_endpoint: + os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" try: - process = subprocess.run(cmd, check=True) + if repo_type == "huggingface": + from huggingface_hub import snapshot_download - console.print() - print_success(t("download_complete")) - console.print() - console.print(f" {t('model_saved_to', path=download_path)}") - console.print() - console.print(f" {t('model_start_with', name=model_name)}") - console.print() + snapshot_download( + repo_id=repo, + local_dir=str(download_path), + allow_patterns=file_pattern if file_pattern != "*" else None, + local_dir_use_symlinks=False, + resume_download=resume, + ) - except subprocess.CalledProcessError as e: - print_error(t("model_download_failed", error=str(e))) + else: # modelscope + from modelscope.hub.snapshot_download import snapshot_download + + snapshot_download( + model_id=repo, + local_dir=str(download_path), + allow_file_pattern=file_pattern if file_pattern != "*" else None, + ) + + except ImportError as e: + pkg = "huggingface_hub" if repo_type == "huggingface" else "modelscope" + print_error(f"{pkg} not installed. 
Install: pip install {pkg}") raise typer.Exit(1) - except FileNotFoundError: - print_error(t("model_hf_cli_not_found")) + except Exception as e: + print_error(f"Download failed: {e}") + raise typer.Exit(1) + finally: + # Restore HF_ENDPOINT + if use_mirror and repo_type == "huggingface" and not original_hf_endpoint: + os.environ.pop("HF_ENDPOINT", None) + elif original_hf_endpoint: + os.environ["HF_ENDPOINT"] = original_hf_endpoint + + # ========== Step 6: Scan and register ========== + console.print() + print_success("Download complete!") + + console.print() + print_step("Scanning downloaded model...") + + try: + scanned = scan_single_path(download_path) + except Exception as e: + print_error(f"Failed to scan model: {e}") + console.print() + console.print(f" You can manually add it: [cyan]kt model add {download_path}[/cyan]") + console.print() + raise typer.Exit(1) + + if not scanned: + print_warning("No model files found in downloaded directory.") + console.print() + console.print(" Supported formats: .safetensors, .gguf") + console.print() + return + + # Auto-generate model name + model_name = download_path.name + if user_registry.check_name_conflict(model_name): + model_name = user_registry.suggest_name(model_name) + + # Create and register model + user_model = UserModel( + name=model_name, + path=str(download_path), + format=scanned.format, + repo_type=repo_type, + repo_id=repo, + sha256_status="not_checked", + ) + + try: + user_registry.add_model(user_model) + console.print() + print_success(f"Model registered as: {model_name}") + console.print() + console.print(f" View details: [cyan]kt model info {model_name}[/cyan]") + console.print(f" Run model: [cyan]kt run {model_name}[/cyan]") + console.print(f" Verify integrity: [cyan]kt model verify {model_name}[/cyan]") + console.print() + except Exception as e: + print_error(f"Failed to register model: {e}") + console.print() + console.print(f" You can manually add it: [cyan]kt model add {download_path}[/cyan]") + 
console.print() raise typer.Exit(1) @app.command(name="list") def list_models( - local_only: bool = typer.Option(False, "--local", help="Show only locally downloaded models"), verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed info including paths"), + all_models: bool = typer.Option(False, "--all", help="Show all models (reserved for future use)"), + show_moe: bool = typer.Option(True, "--moe/--no-moe", help="Show MoE model information (default: enabled)"), + no_cache: bool = typer.Option(False, "--no-cache", help="Force re-analyze all models, ignore cache"), ) -> None: - """List available models.""" + """List user-registered models.""" from rich.table import Table - from kt_kernel.cli.utils.model_registry import get_registry + from rich.panel import Panel + from kt_kernel.cli.utils.user_model_registry import UserModelRegistry + from kt_kernel.cli.utils.model_scanner import format_size + import sys + from pathlib import Path as PathLib + + # Try to import analyze_moe_model from multiple locations + analyze_moe_model = None + try: + # Try 1: From kt_kernel.cli.utils + from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model + except ImportError: + try: + # Try 2: From parent directories + analyze_moe_path = PathLib(__file__).parent.parent.parent.parent.parent.parent / "analyze_moe_model.py" + if analyze_moe_path.exists(): + sys.path.insert(0, str(analyze_moe_path.parent)) + from analyze_moe_model import analyze_moe_model + except (ImportError, Exception): + try: + # Try 3: Absolute path + sys.path.insert(0, "/mnt/data2/ljq/ktransformers") + from analyze_moe_model import analyze_moe_model + except (ImportError, Exception): + analyze_moe_model = None + + registry = UserModelRegistry() + models = registry.list_models() - registry = get_registry() console.print() - if local_only: - # Show only local models - local_models = registry.find_local_models() + if not models: + print_warning(t("model_no_registered_models")) + console.print() 
+ console.print(f" {t('model_scan_hint')} [cyan]kt model scan[/cyan]") + console.print(f" {t('model_add_hint')} [cyan]kt model add [/cyan]") + console.print() + return - if not local_models: - print_warning(t("model_no_local_models")) + # Check for models with non-existent paths and remove them automatically + models_to_remove = [] + for model in models: + if not model.path_exists(): + models_to_remove.append(model) + + if models_to_remove: + console.print(f"[yellow]Found {len(models_to_remove)} model(s) with non-existent paths:[/yellow]") + for model in models_to_remove: + console.print(f" [dim]✗ {model.name}: {model.path}[/dim]") + registry.remove_model(model.name) + console.print(f"[green]✓ Automatically removed {len(models_to_remove)} model(s) with missing paths[/green]") + console.print() + + # Refresh the models list + models = registry.list_models() + + if not models: + console.print(f"[dim]No models remaining after cleanup.[/dim]") console.print() - console.print(f" {t('model_download_hint')} [cyan]kt model download [/cyan]") + console.print(f" {t('model_scan_hint')} [cyan]kt model scan[/cyan]") + console.print(f" {t('model_add_hint')} [cyan]kt model add [/cyan]") console.print() return - table = Table(title=t("model_local_models_title"), show_header=True, header_style="bold") - table.add_column(t("model_column_model"), style="cyan", no_wrap=True) - if verbose: - table.add_column(t("model_column_local_path"), style="dim") + if verbose: + # Verbose mode: detailed cards + console.print(f"[bold cyan]{t('model_registered_models_title')}[/bold cyan]\n") - for model_info, model_path in local_models: - if verbose: - table.add_row(model_info.name, str(model_path)) + for i, model in enumerate(models, 1): + # Check if path exists + path_status = "[green]✓ Exists[/green]" if model.path_exists() else "[red]✗ Missing[/red]" + + # Format repo info + if model.repo_id: + repo_abbr = "hf" if model.repo_type == "huggingface" else "ms" + repo_info = 
f"{repo_abbr}:{model.repo_id}" else: - table.add_row(model_info.name) + repo_info = "-" - console.print(table) + # Format SHA256 status + sha256_display = SHA256_STATUS_MAP_PLAIN.get(model.sha256_status, model.sha256_status) + + # Calculate folder size if exists + if model.path_exists(): + from pathlib import Path + + path_obj = Path(model.path) + try: + if model.format == "safetensors": + files = list(path_obj.glob("*.safetensors")) + else: + files = list(path_obj.glob("*.gguf")) + + total_size = sum(f.stat().st_size for f in files if f.exists()) + size_str = format_size(total_size) + file_count = len(files) + size_info = f"{size_str} ({file_count} files)" + except: + size_info = "Unknown" + else: + size_info = "-" + + # Create panel content + content = f"""[bold]Path:[/bold] {model.path} +[bold]Format:[/bold] {model.format} +[bold]Repo:[/bold] {repo_info} +[bold]SHA256:[/bold] {sha256_display} +[bold]Size:[/bold] {size_info} +[bold]Status:[/bold] {path_status}""" + + panel = Panel(content, title=f"[cyan]{model.name}[/cyan]", border_style="cyan", padding=(0, 1)) + console.print(panel) + + console.print() + console.print(f"[dim]Total: {len(models)} model(s)[/dim]\n") else: - # Show all registered models - all_models = registry.list_all() - local_models_dict = {m.name: p for m, p in registry.find_local_models()} + # Compact mode: separate tables by model type + from rich.align import Align + from pathlib import Path - table = Table(title=t("model_available_models_title"), show_header=True, header_style="bold") - table.add_column(t("model_column_model"), style="cyan", no_wrap=True) - table.add_column(t("model_column_status"), justify="center") - if verbose: - table.add_column(t("model_column_local_path"), style="dim") + # Categorize models + gguf_models = [] + amx_models = [] + gpu_models = [] - for model in all_models: - if model.name in local_models_dict: - status = f"[green]✓ {t('model_status_local')}[/green]" - local_path = str(local_models_dict[model.name]) + 
for model in models: + if model.format == "gguf": + gguf_models.append(model) + elif model.format == "safetensors" and model.path_exists(): + is_amx, numa_count = is_amx_weights(model.path) + if is_amx: + amx_models.append((model, numa_count)) + else: + gpu_models.append(model) else: - status = "[dim]-[/dim]" - local_path = f"[dim]{t('model_status_not_downloaded')}[/dim]" + gpu_models.append(model) - if verbose: - table.add_row(model.name, status, local_path) + # Pre-analyze GPU MoE models concurrently if enabled + moe_results = {} + moe_failed_models = [] # Track models that failed MoE analysis + if show_moe and analyze_moe_model and gpu_models: + from concurrent.futures import ThreadPoolExecutor, as_completed + import threading + + # Collect GPU models that need MoE analysis + # Priority: use cached MoE info from UserModel, only analyze if is_moe is None + models_to_analyze = [] + models_need_update = [] # Track models that need registry update + + for model in gpu_models: + # Check if MoE info is already cached in UserModel (and not using --no-cache) + if not no_cache and model.is_moe is not None: + # Use cached info from UserModel + if model.is_moe: + moe_results[model.name] = { + "is_moe": True, + "num_experts": model.moe_num_experts, + "num_experts_per_tok": model.moe_num_experts_per_tok, + "cached": True, + } + # If is_moe is False, don't add to moe_results + else: + # Need to analyze (is_moe is None or --no-cache) + path_obj = Path(model.path) + models_to_analyze.append((model.name, str(path_obj))) + models_need_update.append(model) + + if models_to_analyze: + # Use lock for thread-safe console output + print_lock = threading.Lock() + completed_count = [0] # Use list to allow modification in nested function + + def analyze_with_progress(model_info): + model_name, model_path = model_info + try: + with print_lock: + console.print(f"[dim]Analyzing MoE: {model_name}...[/dim]") + result = analyze_moe_model(model_path, use_cache=not no_cache) + + # Check if 
analysis returned valid results + if result is None or result.get("num_experts", 0) == 0: + with print_lock: + completed_count[0] += 1 + console.print( + f"[dim]✗ [{completed_count[0]}/{len(models_to_analyze)}] {model_name} - Not a MoE model or analysis failed[/dim]" + ) + return (model_name, None, "Not a MoE model or analysis failed") + + with print_lock: + completed_count[0] += 1 + cached_tag = "[green](cached)[/green]" if result and result.get("cached") else "" + console.print( + f"[dim]✓ [{completed_count[0]}/{len(models_to_analyze)}] {model_name} {cached_tag}[/dim]" + ) + return (model_name, result, None) + except Exception as e: + with print_lock: + completed_count[0] += 1 + error_msg = str(e)[:80] + console.print( + f"[dim]✗ [{completed_count[0]}/{len(models_to_analyze)}] {model_name} - Error: {error_msg}[/dim]" + ) + return (model_name, None, error_msg) + + if no_cache: + console.print(f"\n[yellow]Force re-analyzing (--no-cache): ignoring cached results[/yellow]") + console.print( + f"\n[cyan]Analyzing {len(models_to_analyze)} MoE model(s) with {min(16, len(models_to_analyze))} threads...[/cyan]\n" + ) + + # Analyze concurrently with up to 16 workers + with ThreadPoolExecutor(max_workers=16) as executor: + futures = { + executor.submit(analyze_with_progress, model_info): model_info + for model_info in models_to_analyze + } + + for future in as_completed(futures): + model_name, result, error = future.result() + if error: + # Find the model object + failed_model = next((m for m in gpu_models if m.name == model_name), None) + if failed_model: + moe_failed_models.append((failed_model, error)) + # Update model registry: mark as non-MoE + registry.update_model(model_name, {"is_moe": False}) + else: + moe_results[model_name] = result + # Update model registry with MoE info + if result and result.get("is_moe"): + registry.update_model( + model_name, + { + "is_moe": True, + "moe_num_experts": result.get("num_experts"), + "moe_num_experts_per_tok": 
result.get("num_experts_per_tok"), + }, + ) + else: + registry.update_model(model_name, {"is_moe": False}) + + console.print(f"\n[green]✓ MoE analysis complete[/green]\n") + + # Remove failed models from gpu_models list + if moe_failed_models: + failed_names = {m.name for m, _ in moe_failed_models} + gpu_models = [m for m in gpu_models if m.name not in failed_names] + + # Separate MoE and non-MoE GPU models + moe_gpu_models = [] + non_moe_gpu_models = [] + for model in gpu_models: + if model.name in moe_results: + moe_gpu_models.append(model) else: - table.add_row(model.name, status) + non_moe_gpu_models.append(model) - console.print(table) + # Count failed MoE models (these are also non-MoE) + total_non_moe_count = len(non_moe_gpu_models) + len(moe_failed_models) + + # Filter display based on --all flag + if not all_models: + # Default: only show MoE models + gpu_models_to_display = moe_gpu_models + show_failed_table = False + else: + # --all: show all GPU models including non-MoE and failed + gpu_models_to_display = gpu_models + show_failed_table = True + total_non_moe_count = 0 # Don't show hint when displaying all + + # Helper function to create table rows + def format_model_row(model, moe_info=None, numa_count=None): + from kt_kernel.cli.utils.model_scanner import format_size + + # Calculate size + if model.path_exists(): + path_obj = Path(model.path) + try: + if model.format == "safetensors": + files = list(path_obj.glob("*.safetensors")) + else: + files = list(path_obj.glob("*.gguf")) + + total_size = sum(f.stat().st_size for f in files if f.exists()) + size_display = format_size(total_size) + except: + size_display = "[dim]-[/dim]" + else: + size_display = "[dim]-[/dim]" + + # Format repo info + if model.repo_id: + repo_abbr = "hf" if model.repo_type == "huggingface" else "ms" + repo_display = f"{repo_abbr}:{model.repo_id}" + else: + repo_display = "[dim]-[/dim]" + + # Format SHA256 status + sha256_display = SHA256_STATUS_MAP.get(model.sha256_status, 
model.sha256_status) + + row = [model.name, model.path, size_display] + + # Add type-specific columns + if numa_count is not None: + # AMX model + row.append(f"[yellow]{numa_count} NUMA[/yellow]") + elif moe_info: + # GPU MoE model + experts_display = f"[yellow]{moe_info['num_experts']}[/yellow]" + activated_display = f"[green]{moe_info['num_experts_per_tok']}[/green]" + moe_total_display = f"[cyan]{size_display}[/cyan]" + row.extend([experts_display, activated_display, moe_total_display]) + elif show_moe and analyze_moe_model and model.format == "safetensors": + # GPU non-MoE model + row.extend(["[dim]-[/dim]", "[dim]-[/dim]", "[dim]-[/dim]"]) + + row.extend([repo_display, sha256_display]) + return row + + # Display tables + title = Align.center(f"[bold cyan]{t('model_registered_models_title')}[/bold cyan]") + console.print(title) + console.print() + + # Table 1: GGUF Models (Llamafile) + if gguf_models: + console.print("[bold yellow]GGUF Models (Llamafile)[/bold yellow]") + table = Table(show_header=True, header_style="bold") + table.add_column("#", justify="right", style="cyan", no_wrap=True) + table.add_column(t("model_column_name"), style="cyan", no_wrap=True) + table.add_column("Path", style="dim", overflow="fold") + table.add_column("Total", justify="right") + table.add_column(t("model_column_repo"), style="dim", overflow="fold") + table.add_column(t("model_column_sha256"), justify="center") + + for i, model in enumerate(gguf_models, 1): + row = [str(i)] + format_model_row(model) + table.add_row(*row) + + console.print(table) + console.print() + + # Table 2: AMX Models + if amx_models: + from kt_kernel.cli.utils.model_scanner import format_size + import json + + console.print("[bold magenta]AMX Models (CPU)[/bold magenta]") + table = Table(show_header=True, header_style="bold", show_lines=False) + table.add_column("#", justify="right", style="cyan", no_wrap=True) + table.add_column(t("model_column_name"), style="cyan", no_wrap=True) + 
table.add_column("Path", style="dim", overflow="fold") + table.add_column("Total", justify="right") + table.add_column("Method", justify="center", style="yellow") + table.add_column("NUMA", justify="center", style="green") + table.add_column("Source", style="dim", overflow="fold") + + # Build reverse map: AMX model ID -> GPU models using it + amx_used_by_gpu = {} # {amx_model_id: [gpu_model_names]} + for model, _ in amx_models: + if model.gpu_model_ids: + # This AMX is linked to these GPU models + gpu_names = [] + for gpu_id in model.gpu_model_ids: + # Find GPU model by ID + for gpu_model in gpu_models: + if gpu_model.id == gpu_id: + gpu_names.append(gpu_model.name) + break + if gpu_names: + amx_used_by_gpu[model.id] = gpu_names + + for i, (model, numa_count) in enumerate(amx_models, 1): + # Calculate size + if model.path_exists(): + path_obj = Path(model.path) + try: + files = list(path_obj.glob("*.safetensors")) + total_size = sum(f.stat().st_size for f in files if f.exists()) + size_display = format_size(total_size) + except: + size_display = "[dim]-[/dim]" + else: + size_display = "[dim]-[/dim]" + + # Read AMX metadata from config.json (fallback if not in UserModel) + method_from_config = None + numa_from_config = None + if model.path_exists(): + config_path = Path(model.path) / "config.json" + if config_path.exists(): + try: + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + amx_quant = config.get("amx_quantization", {}) + if amx_quant.get("converted"): + method_from_config = amx_quant.get("method") + numa_from_config = amx_quant.get("numa_count") + except: + pass + + # AMX-specific metadata (priority: UserModel > config.json > detected numa_count) + method_display = ( + model.amx_quant_method.upper() + if model.amx_quant_method + else method_from_config.upper() if method_from_config else "[dim]?[/dim]" + ) + numa_display = ( + str(model.amx_numa_nodes) + if model.amx_numa_nodes + else ( + str(numa_from_config) if numa_from_config 
else str(numa_count) if numa_count else "[dim]?[/dim]" + ) + ) + source_display = model.amx_source_model if model.amx_source_model else "[dim]-[/dim]" + + table.add_row( + str(i), model.name, model.path, size_display, method_display, numa_display, source_display + ) + + # Add linked GPU models info below this AMX model + if model.id in amx_used_by_gpu: + gpu_list = amx_used_by_gpu[model.id] + gpu_names_str = ", ".join([f"[dim]{name}[/dim]" for name in gpu_list]) + # Create a sub-row with empty cells except for the first column (7 columns total with #) + sub_row = ["", f" [dim]↳ GPU: {gpu_names_str}[/dim]", "", "", "", "", ""] + table.add_row(*sub_row, style="dim") + + console.print(table) + console.print() + + # Table 3: GPU Models (Safetensors) + if gpu_models_to_display: + console.print("[bold green]GPU Models (Safetensors)[/bold green]") + table = Table(show_header=True, header_style="bold", show_lines=False) + table.add_column("#", justify="right", style="cyan", no_wrap=True) + table.add_column(t("model_column_name"), style="cyan", no_wrap=True) + table.add_column("Path", style="dim", overflow="fold") + table.add_column("Total", justify="right") + + if show_moe and analyze_moe_model: + table.add_column("Exps", justify="center", style="yellow") + table.add_column("Act", justify="center", style="green") + table.add_column("MoE Size", justify="right", style="cyan") + + table.add_column(t("model_column_repo"), style="dim", overflow="fold") + table.add_column(t("model_column_sha256"), justify="center") + + # Build a map of GPU model UUID -> attached CPU models + attached_cpu_models = {} # {gpu_model_id: [(cpu_model, type)]} + for model in gguf_models: + if model.gpu_model_ids: + for gpu_id in model.gpu_model_ids: + if gpu_id not in attached_cpu_models: + attached_cpu_models[gpu_id] = [] + attached_cpu_models[gpu_id].append((model, "GGUF")) + + for model, numa_count in amx_models: + if model.gpu_model_ids: + for gpu_id in model.gpu_model_ids: + if gpu_id not in 
attached_cpu_models: + attached_cpu_models[gpu_id] = [] + attached_cpu_models[gpu_id].append((model, "AMX")) + + for i, model in enumerate(gpu_models_to_display, 1): + moe_info = moe_results.get(model.name) if show_moe and analyze_moe_model else None + row = [str(i)] + format_model_row(model, moe_info=moe_info) + table.add_row(*row) + + # Add attached CPU models info below this GPU model (using UUID matching) + if model.id in attached_cpu_models: + cpu_list = attached_cpu_models[model.id] + cpu_names = ", ".join([f"[dim]{m.name} ({t})[/dim]" for m, t in cpu_list]) + # Create a sub-row with empty cells except for the first column + num_cols = len(row) + sub_row = ["", f" [dim]↳ CPU: {cpu_names}[/dim]"] + [""] * (num_cols - 2) + table.add_row(*sub_row, style="dim") + + console.print(table) + console.print() + + # Table 4: Failed MoE Analysis (only show with --all) + if show_failed_table and moe_failed_models: + console.print("[bold red]Failed MoE Analysis[/bold red]") + console.print("[yellow]These models may not be MoE models or have analysis errors:[/yellow]\n") + table = Table(show_header=True, header_style="bold") + table.add_column("#", justify="right", style="cyan", no_wrap=True) + table.add_column(t("model_column_name"), style="red", no_wrap=True) + table.add_column("Path", style="dim", overflow="fold") + table.add_column("Total", justify="right") + table.add_column("Error", style="yellow", overflow="fold") + + for i, (model, error) in enumerate(moe_failed_models, 1): + from kt_kernel.cli.utils.model_scanner import format_size + + if model.path_exists(): + path_obj = Path(model.path) + try: + files = list(path_obj.glob("*.safetensors")) + total_size = sum(f.stat().st_size for f in files if f.exists()) + size_display = format_size(total_size) + except: + size_display = "[dim]-[/dim]" + else: + size_display = "[dim]-[/dim]" + + table.add_row(str(i), model.name, model.path, size_display, error) + + console.print(table) + console.print() + + # Show hint if non-MoE 
models are hidden (display before summary) + if total_non_moe_count > 0: + hint_text = t("model_non_moe_hidden_hint", count=total_non_moe_count) + console.print(f"[dim]{hint_text}[/dim]") + console.print() + + # Summary + total_count = len(gguf_models) + len(amx_models) + len(gpu_models) + failed_count = len(moe_failed_models) + if failed_count > 0: + console.print( + f"[dim]Total: {total_count} model(s) | GGUF: {len(gguf_models)} | AMX: {len(amx_models)} | GPU: {len(gpu_models)} | [red]Failed: {failed_count}[/red][/dim]\n" + ) + else: + console.print( + f"[dim]Total: {total_count} model(s) | GGUF: {len(gguf_models)} | AMX: {len(amx_models)} | GPU: {len(gpu_models)}[/dim]\n" + ) + + # Show usage hints (only in non-verbose mode) + if not verbose and models: + console.print(f"[bold cyan]{t('model_usage_title')}[/bold cyan]") + console.print(f" {t('model_usage_info'):<17} [cyan]kt model info [/cyan]") + console.print(f" {t('model_usage_edit'):<17} [cyan]kt model edit [/cyan]") + console.print(f" {t('model_usage_verify'):<17} [cyan]kt model verify [/cyan]") + console.print(f" {t('model_usage_quant'):<17} [cyan]kt quant [/cyan]") + console.print(f" {t('model_usage_run'):<17} [cyan]kt run [/cyan]") + console.print() + console.print(f" {t('model_usage_scan'):<17} [cyan]kt model scan[/cyan]") + console.print(f" {t('model_usage_add'):<17} [cyan]kt model add [/cyan]") + console.print() + + +@app.command(name="clear-cache") +def clear_cache() -> None: + """Clear MoE analysis cache.""" + from pathlib import Path + import json + + cache_file = Path.home() / ".ktransformers" / "cache" / "moe_analysis.json" + + if not cache_file.exists(): + console.print() + console.print("[dim]No MoE cache found.[/dim]") + console.print() + return + + # Read cache to count entries + try: + with open(cache_file, "r") as f: + cache_data = json.load(f) + cache_count = len(cache_data) + except Exception: + cache_count = 0 + + if cache_count == 0: + console.print() + console.print("[dim]MoE cache is 
empty.[/dim]") + console.print() + return + + console.print() + console.print(f"[yellow]Found {cache_count} cached model(s) in:[/yellow]") + console.print(f" {cache_file}") + console.print() + + if confirm("Clear all MoE analysis cache?", default=False): + cache_file.unlink() + console.print(f"[green]✓ Cleared cache for {cache_count} model(s)[/green]") + else: + console.print("[dim]Cache clear cancelled.[/dim]") console.print() @@ -332,6 +1073,95 @@ def path_list() -> None: console.print() +@app.command(name="link-cpu") +def link_cpu( + cpu_model: str = typer.Argument(..., help="Name of the CPU model (GGUF/AMX)"), + gpu_models: List[str] = typer.Argument(..., help="Name(s) of GPU model(s) to link with"), +) -> None: + """Link a CPU model (GGUF/AMX) with one or more GPU models for joint startup.""" + from kt_kernel.cli.utils.user_model_registry import UserModelRegistry + + registry = UserModelRegistry() + + # Check if CPU model exists + cpu_model_obj = registry.get_model(cpu_model) + if not cpu_model_obj: + print_error(f"CPU model '{cpu_model}' not found in registry.") + console.print() + console.print(f" Use [cyan]kt model list[/cyan] to see registered models") + console.print() + raise typer.Exit(1) + + # Check if it's actually a CPU model (GGUF or AMX) + if cpu_model_obj.format == "safetensors": + # Check if it's AMX by looking for .numa. 
pattern + is_amx, _ = is_amx_weights(cpu_model_obj.path) + if not is_amx: + print_error(f"Model '{cpu_model}' is a GPU model (safetensors), not a CPU model.") + console.print() + console.print(f" Only GGUF and AMX models can be linked to GPU models") + console.print() + raise typer.Exit(1) + + # Verify all GPU models exist and collect their UUIDs + gpu_model_uuids = [] + missing_models = [] + for gpu_name in gpu_models: + gpu_model_obj = registry.get_model(gpu_name) + if not gpu_model_obj: + missing_models.append(gpu_name) + else: + gpu_model_uuids.append(gpu_model_obj.id) + + if missing_models: + print_error(f"GPU model(s) not found: {', '.join(missing_models)}") + console.print() + console.print(f" Use [cyan]kt model list[/cyan] to see registered models") + console.print() + raise typer.Exit(1) + + # Update the CPU model with GPU links (using UUIDs for stability) + registry.update_model(cpu_model, {"gpu_model_ids": gpu_model_uuids}) + + console.print() + print_success(f"Linked CPU model '{cpu_model}' with GPU model(s):") + for gpu_name in gpu_models: + console.print(f" [green]✓[/green] {gpu_name}") + console.print() + console.print(f" View the relationship with [cyan]kt model list[/cyan]") + console.print() + + +@app.command(name="unlink-cpu") +def unlink_cpu( + cpu_model: str = typer.Argument(..., help="Name of the CPU model to unlink"), +) -> None: + """Remove GPU model links from a CPU model.""" + from kt_kernel.cli.utils.user_model_registry import UserModelRegistry + + registry = UserModelRegistry() + + # Check if model exists + model = registry.get_model(cpu_model) + if not model: + print_error(f"Model '{cpu_model}' not found in registry.") + console.print() + raise typer.Exit(1) + + if not model.gpu_model_ids: + console.print() + console.print(f"[yellow]Model '{cpu_model}' has no GPU links.[/yellow]") + console.print() + return + + # Remove links + registry.update_model(cpu_model, {"gpu_model_ids": None}) + + console.print() + print_success(f"Removed all 
GPU links from '{cpu_model}'") + console.print() + + @app.command(name="path-add") def path_add( path: str = typer.Argument(..., help="Path to add"), @@ -376,34 +1206,1605 @@ def path_remove( raise typer.Exit(1) -@app.command(name="search") -def search( - query: str = typer.Argument(..., help="Search query (model name or keyword)"), +@app.command(name="scan") +def scan( + min_size: float = typer.Option(2.0, "--min-size", help="Minimum model file size in GB (default: 2.0)"), + max_depth: int = typer.Option(6, "--max-depth", help="Maximum search depth (default: 6)"), ) -> None: - """Search for models in the registry.""" - from rich.table import Table - from kt_kernel.cli.utils.model_registry import get_registry + """Perform global scan for models and add new ones to registry.""" + from kt_kernel.cli.utils.model_discovery import discover_and_register_global, format_discovery_summary + from kt_kernel.cli.config.settings import get_settings - registry = get_registry() - matches = registry.search(query) + settings = get_settings() + lang = settings.get("general.language", "en") + + console.print() + if lang == "zh": + print_info("全局扫描模型权重") + console.print() + else: + print_info("Global Model Scan") + console.print() + + try: + total_found, new_found, registered = discover_and_register_global( + min_size_gb=min_size, max_depth=max_depth, show_progress=True, lang=lang + ) + + format_discovery_summary( + total_found=total_found, + new_found=new_found, + registered=registered, + lang=lang, + show_models=True, + max_show=20, + ) + + if new_found > 0: + console.print() + if lang == "zh": + console.print("[dim]下一步:[/dim]") + console.print(f" • 查看模型列表: [cyan]kt model list[/cyan]") + console.print(f" • 编辑模型信息: [cyan]kt model edit [/cyan]") + console.print(f" • 验证模型: [cyan]kt model verify [/cyan]") + else: + console.print("[dim]Next steps:[/dim]") + console.print(f" • View model list: [cyan]kt model list[/cyan]") + console.print(f" • Edit model info: [cyan]kt model edit [/cyan]") 
+ console.print(f" • Verify models: [cyan]kt model verify [/cyan]") + console.print() + + except Exception as e: + print_error(f"Scan failed: {e}") + raise typer.Exit(1) + + +@app.command(name="add") +def add_model( + path: str = typer.Argument(..., help="Path to scan for models"), +) -> None: + """Scan a directory and add all found models to the registry.""" + from pathlib import Path + from kt_kernel.cli.utils.model_discovery import discover_and_register_path + from kt_kernel.cli.config.settings import get_settings + + settings = get_settings() + lang = settings.get("general.language", "en") + + # Expand and validate path + path_obj = Path(os.path.expanduser(path)).resolve() + + if not path_obj.exists(): + print_error(f"Path does not exist: {path_obj}") + raise typer.Exit(1) + + if not path_obj.is_dir(): + print_error(f"Not a directory: {path_obj}") + raise typer.Exit(1) + + # Scan and register models + console.print() + try: + total_found, new_found, registered = discover_and_register_path( + path=str(path_obj), min_size_gb=2.0, existing_paths=None, show_progress=True, lang=lang + ) + + console.print() + if new_found == 0: + if total_found > 0: + if lang == "zh": + console.print(f"[yellow]在此路径找到 {total_found} 个模型,但所有模型均已在列表中[/yellow]") + else: + console.print( + f"[yellow]Found {total_found} models in this path, but all already in the list[/yellow]" + ) + else: + if lang == "zh": + console.print("[yellow]未找到模型[/yellow]") + console.print() + console.print(" 支持的格式: *.gguf, *.safetensors (需要 config.json)") + else: + console.print("[yellow]No models found[/yellow]") + console.print() + console.print(" Supported formats: *.gguf, *.safetensors (with config.json)") + else: + if lang == "zh": + console.print( + f"[green]✓[/green] 在此路径找到 {total_found} 个模型,成功添加 {len(registered)} 个新模型" + ) + else: + console.print( + f"[green]✓[/green] Found {total_found} models in this path, added {len(registered)} new models" + ) + + if registered: + console.print() + if lang == "zh": + 
console.print("[dim]新添加的模型:[/dim]") + else: + console.print("[dim]Newly added models:[/dim]") + + for model in registered: + console.print(f" • {model.name} ({model.format})") + console.print(f" [dim]{model.path}[/dim]") + + console.print() + + except Exception as e: + print_error(f"Failed to scan path: {e}") + raise typer.Exit(1) + + +@app.command(name="edit") +def edit_model( + name: Optional[str] = typer.Argument( + None, help="Name of model to edit (optional - will show selection if not provided)" + ), +) -> None: + """Edit model information interactively.""" + from rich.prompt import Prompt, Confirm + from rich.panel import Panel + from rich.table import Table + from kt_kernel.cli.utils.user_model_registry import UserModelRegistry + + registry = UserModelRegistry() + + # If no name provided, show interactive selection + if name is None: + all_models = registry.list_models() + + # Filter to only show MoE GPU models (safetensors that are not AMX) + moe_models = [] + for m in all_models: + if m.format == "safetensors": + is_amx_model, _ = is_amx_weights(m.path) + if not is_amx_model: + moe_models.append(m) + + if not moe_models: + print_error(t("model_edit_no_models")) + console.print() + console.print(f" {t('model_edit_add_hint_scan')} [cyan]kt model scan[/cyan]") + console.print(f" {t('model_edit_add_hint_add')} [cyan]kt model add [/cyan]") + console.print() + raise typer.Exit(1) + + # Display models table with # column + console.print() + console.print(f"[bold cyan]{t('model_edit_select_title')}[/bold cyan]") + console.print() + + table = Table(show_header=True, header_style="bold", show_lines=False) + table.add_column("#", justify="right", style="cyan", no_wrap=True) + table.add_column("Name", style="cyan", no_wrap=True) + table.add_column("Format", style="dim") + table.add_column("Path", style="dim", overflow="fold") + + for i, model_item in enumerate(moe_models, 1): + table.add_row(str(i), model_item.name, model_item.format, model_item.path) + + 
console.print(table) + console.print() + + from rich.prompt import IntPrompt + + choice = IntPrompt.ask(t("model_edit_select_model"), default=1, show_choices=False) + + if choice < 1 or choice > len(moe_models): + print_error(t("model_edit_invalid_choice")) + raise typer.Exit(1) + + model = moe_models[choice - 1] + else: + # Load model by name + model = registry.get_model(name) + if not model: + print_error(t("model_edit_not_found", name=name)) + console.print() + console.print(f" {t('model_edit_list_hint')} [cyan]kt model list[/cyan]") + console.print() + raise typer.Exit(1) + + # Keep track of original values to detect changes + original_name = model.name + original_repo_type = model.repo_type + original_repo_id = model.repo_id + original_gpu_model_ids = model.gpu_model_ids.copy() if model.gpu_model_ids else None + + # Working copy for edits (not saved until user confirms) + edited_name = model.name + edited_repo_type = model.repo_type + edited_repo_id = model.repo_id + edited_gpu_model_ids = model.gpu_model_ids.copy() if model.gpu_model_ids else None + + has_changes = False + + while True: + # Display current configuration (show edited values) + console.print() + console.print(f"[bold cyan]{t('model_edit_current_config')}[/bold cyan]\n") + + # Format SHA256 status (from original model) + sha256_display = SHA256_STATUS_MAP_PLAIN.get(model.sha256_status, model.sha256_status) + + # Check if this is a CPU model (GGUF or AMX) + is_cpu_model = model.format == "gguf" + if not is_cpu_model and model.format == "safetensors": + is_amx, _ = is_amx_weights(model.path) + is_cpu_model = is_amx + + # Format GPU links info (for CPU models) + gpu_links_info = "" + if is_cpu_model and edited_gpu_model_ids: + gpu_names = [] + for gpu_id in edited_gpu_model_ids: + gpu_obj = registry.get_model_by_id(gpu_id) + if gpu_obj: + gpu_names.append(gpu_obj.name) + else: + gpu_names.append(f"[dim red]{gpu_id[:8]}... 
(deleted)[/dim red]") + gpu_links_info = f"\n[bold]{t('model_edit_gpu_links')}[/bold] {', '.join(gpu_names)}" + + content = f"""[bold]Name:[/bold] {edited_name} +[bold]Path:[/bold] {model.path} +[bold]Format:[/bold] {model.format} +[bold]Repo Type:[/bold] {edited_repo_type or '-'} +[bold]Repo ID:[/bold] {edited_repo_id or '-'} +[bold]SHA256:[/bold] {sha256_display}{gpu_links_info}""" + + panel = Panel(content, border_style="cyan", padding=(0, 1)) + console.print(panel) + console.print() + + # Check if there are any changes + has_changes = ( + edited_name != original_name + or edited_repo_type != original_repo_type + or edited_repo_id != original_repo_id + or edited_gpu_model_ids != original_gpu_model_ids + ) + + # Show menu + console.print(f"[bold]{t('model_edit_what_to_edit')}[/bold]") + console.print(" [1] " + t("model_edit_option_name")) + console.print(" [2] " + t("model_edit_option_repo")) + console.print(" [3] " + t("model_edit_option_delete")) + if is_cpu_model: + console.print(" [4] " + t("model_edit_manage_gpu_links")) + save_option = "5" + cancel_option = "6" + console.print( + f" [{save_option}] {t('model_edit_save_changes')}" + + ( + f" [cyan]{t('model_edit_has_changes')}[/cyan]" + if has_changes + else f" [dim]{t('model_edit_no_changes')}[/dim]" + ) + ) + console.print(f" [{cancel_option}] " + t("model_edit_option_cancel")) + console.print() + choice = Prompt.ask(t("model_edit_choice_prompt"), choices=["1", "2", "3", "4", "5", "6"], default="6") + else: + save_option = "4" + cancel_option = "5" + console.print( + f" [{save_option}] {t('model_edit_save_changes')}" + + ( + f" [cyan]{t('model_edit_has_changes')}[/cyan]" + if has_changes + else f" [dim]{t('model_edit_no_changes')}[/dim]" + ) + ) + console.print(f" [{cancel_option}] " + t("model_edit_option_cancel")) + console.print() + choice = Prompt.ask(t("model_edit_choice_prompt"), choices=["1", "2", "3", "4", "5"], default="5") + + if choice == "1": + # Edit name (update working copy only) + 
console.print() + new_name = Prompt.ask(t("model_edit_new_name"), default=edited_name) + + if new_name != edited_name: + # Check for conflict (excluding both original and edited names) + if new_name != original_name and registry.check_name_conflict(new_name, exclude_name=original_name): + print_error(t("model_edit_name_conflict", name=new_name)) + continue + + edited_name = new_name + console.print() + print_info(f"[dim]{t('model_edit_name_pending')}[/dim]") + + elif choice == "2": + # Edit repo configuration (update working copy only) + console.print() + console.print(t("model_edit_repo_type_prompt")) + console.print(" [1] HuggingFace") + console.print(" [2] ModelScope") + console.print(" [3] " + t("model_edit_repo_remove")) + console.print() + + repo_choice = Prompt.ask(t("model_edit_choice_prompt"), choices=["1", "2", "3"], default="3") + + if repo_choice == "3": + # Remove repo + edited_repo_type = None + edited_repo_id = None + console.print() + print_info(f"[dim]{t('model_edit_repo_remove_pending')}[/dim]") + else: + # Set repo + repo_type = "huggingface" if repo_choice == "1" else "modelscope" + example = "deepseek-ai/DeepSeek-V3" if repo_choice == "1" else "deepseek/DeepSeek-V3" + + current_default = edited_repo_id if edited_repo_id and edited_repo_type == repo_type else "" + repo_id = Prompt.ask( + t("model_edit_repo_id_prompt", example=example), + default=current_default if current_default else None, + ) + + edited_repo_type = repo_type + edited_repo_id = repo_id + console.print() + print_info(f"[dim]{t('model_edit_repo_update_pending')}[/dim]") + + elif choice == "3": + # Delete model + console.print() + console.print(f"[bold yellow]{t('model_edit_delete_warning')}[/bold yellow]") + console.print(f" {t('model_edit_delete_note')}") + console.print() + + if Confirm.ask(t("model_edit_delete_confirm", name=model.name), default=False): + registry.remove_model(model.name) + console.print() + print_success(t("model_edit_deleted", name=model.name)) + 
console.print() + return + else: + console.print() + print_info(t("model_edit_delete_cancelled")) + + elif choice == "4" and is_cpu_model: + # Manage GPU Links (only for CPU models) - update working copy + console.print() + console.print(f"[bold cyan]{t('model_edit_gpu_links_title', name=edited_name)}[/bold cyan]") + console.print() + + # Show current links (from edited values) + if edited_gpu_model_ids: + console.print(f"[bold]{t('model_edit_current_gpu_links')}[/bold]") + for i, gpu_id in enumerate(edited_gpu_model_ids, 1): + gpu_obj = registry.get_model_by_id(gpu_id) + if gpu_obj: + console.print(f" [{i}] {gpu_obj.name}") + else: + console.print(f" [{i}] [red]{gpu_id[:8]}... (deleted)[/red]") + console.print() + else: + console.print(f"[dim]{t('model_edit_no_gpu_links')}[/dim]") + console.print() + + console.print(f"{t('model_edit_gpu_options')}") + console.print(f" [1] {t('model_edit_gpu_add')}") + console.print(f" [2] {t('model_edit_gpu_remove')}") + console.print(f" [3] {t('model_edit_gpu_clear')}") + console.print(f" [4] {t('model_edit_gpu_back')}") + console.print() + + link_choice = Prompt.ask(t("model_edit_gpu_choose_option"), choices=["1", "2", "3", "4"], default="4") + + if link_choice == "1": + # Add GPU link + # Get all GPU models (safetensors that are not AMX) + all_models = registry.list_models() + available_gpu_models = [] + for m in all_models: + if m.format == "safetensors": + is_amx_model, _ = is_amx_weights(m.path) + if not is_amx_model: + available_gpu_models.append(m) + + if not available_gpu_models: + console.print() + console.print(f"[yellow]{t('model_edit_gpu_none_available')}[/yellow]") + console.print() + else: + console.print() + console.print(f"{t('model_edit_gpu_available_models')}") + for i, gpu_m in enumerate(available_gpu_models, 1): + already_linked = edited_gpu_model_ids and gpu_m.id in edited_gpu_model_ids + status = f" [dim]{t('model_edit_gpu_already_linked')}[/dim]" if already_linked else "" + console.print(f" [{i}] 
{gpu_m.name}{status}") + console.print() + + gpu_choice = Prompt.ask(t("model_edit_gpu_enter_number"), default="0") + try: + gpu_idx = int(gpu_choice) - 1 + if 0 <= gpu_idx < len(available_gpu_models): + selected_gpu = available_gpu_models[gpu_idx] + + # Add to edited_gpu_model_ids + current_ids = list(edited_gpu_model_ids) if edited_gpu_model_ids else [] + if selected_gpu.id not in current_ids: + current_ids.append(selected_gpu.id) + edited_gpu_model_ids = current_ids + console.print() + print_info(f"[dim]{t('model_edit_gpu_link_pending', name=selected_gpu.name)}[/dim]") + else: + console.print() + console.print(f"[yellow]{t('model_edit_gpu_already_exists')}[/yellow]") + else: + console.print() + console.print(f"[red]{t('model_edit_gpu_invalid_choice')}[/red]") + except ValueError: + console.print() + console.print(f"[red]{t('model_edit_gpu_invalid_input')}[/red]") + + elif link_choice == "2": + # Remove GPU link + if not edited_gpu_model_ids: + console.print() + console.print(f"[yellow]{t('model_edit_gpu_none_to_remove')}[/yellow]") + console.print() + else: + console.print() + console.print(f"{t('model_edit_gpu_choose_to_remove')}") + gpu_list = [] + for i, gpu_id in enumerate(edited_gpu_model_ids, 1): + gpu_obj = registry.get_model_by_id(gpu_id) + gpu_name = gpu_obj.name if gpu_obj else f"{gpu_id[:8]}... 
(deleted)" + gpu_list.append((gpu_id, gpu_name)) + console.print(f" [{i}] {gpu_name}") + console.print() + + remove_choice = Prompt.ask(t("model_edit_gpu_enter_to_remove"), default="0") + try: + remove_idx = int(remove_choice) - 1 + if 0 <= remove_idx < len(gpu_list): + removed_id, removed_name = gpu_list[remove_idx] + new_ids = [gid for gid in edited_gpu_model_ids if gid != removed_id] + edited_gpu_model_ids = new_ids if new_ids else None + console.print() + print_info(f"[dim]{t('model_edit_gpu_remove_pending', name=removed_name)}[/dim]") + else: + console.print() + console.print(f"[red]{t('model_edit_gpu_invalid_choice')}[/red]") + except ValueError: + console.print() + console.print(f"[red]{t('model_edit_gpu_invalid_input')}[/red]") + + elif link_choice == "3": + # Clear all GPU links + if not edited_gpu_model_ids: + console.print() + console.print(f"[yellow]{t('model_edit_gpu_none_to_clear')}[/yellow]") + console.print() + else: + if Confirm.ask(t("model_edit_gpu_clear_confirm"), default=False): + edited_gpu_model_ids = None + console.print() + print_info(f"[dim]{t('model_edit_gpu_clear_pending')}[/dim]") + else: + console.print() + print_info(t("model_edit_cancelled_short")) + + elif choice == save_option: + # Save changes + if not has_changes: + console.print() + print_info(f"[dim]{t('model_edit_no_changes_to_save')}[/dim]") + continue + + console.print() + console.print(f"[bold cyan]{t('model_edit_saving')}[/bold cyan]") + console.print() + + # Determine if repo info changed (for verification prompt) + repo_changed = (original_repo_id is None and edited_repo_id is not None) or ( + original_repo_id != edited_repo_id + ) + + # Build updates dict + updates = {} + if edited_name != original_name: + updates["name"] = edited_name + if edited_repo_type != original_repo_type: + updates["repo_type"] = edited_repo_type + if edited_repo_id != original_repo_id: + updates["repo_id"] = edited_repo_id + # Update SHA256 status when repo changes + if edited_repo_id is None: 
+ updates["sha256_status"] = "no_repo" + else: + updates["sha256_status"] = "not_checked" + if edited_gpu_model_ids != original_gpu_model_ids: + updates["gpu_model_ids"] = edited_gpu_model_ids + + # Save to registry + registry.update_model(original_name, updates) + print_success(t("model_edit_saved")) + + # Update local model object + if "name" in updates: + model.name = edited_name + if "repo_type" in updates: + model.repo_type = edited_repo_type + if "repo_id" in updates: + model.repo_id = edited_repo_id + if "sha256_status" in updates: + model.sha256_status = updates["sha256_status"] + if "gpu_model_ids" in updates: + model.gpu_model_ids = edited_gpu_model_ids + + # Update original values for next iteration + original_name = edited_name + original_repo_type = edited_repo_type + original_repo_id = edited_repo_id + original_gpu_model_ids = edited_gpu_model_ids.copy() if edited_gpu_model_ids else None + + # Display updated configuration + console.print() + console.print(f"[bold cyan]{t('model_edit_updated_config')}[/bold cyan]\n") + + sha256_display = SHA256_STATUS_MAP_PLAIN.get(model.sha256_status, model.sha256_status) + gpu_links_info = "" + if is_cpu_model and model.gpu_model_ids: + gpu_names = [] + for gpu_id in model.gpu_model_ids: + gpu_obj = registry.get_model_by_id(gpu_id) + if gpu_obj: + gpu_names.append(gpu_obj.name) + else: + gpu_names.append(f"[dim red]{gpu_id[:8]}... 
(deleted)[/dim red]") + gpu_links_info = f"\n[bold]{t('model_edit_gpu_links')}[/bold] {', '.join(gpu_names)}" + + content = f"""[bold]Name:[/bold] {model.name} +[bold]Path:[/bold] {model.path} +[bold]Format:[/bold] {model.format} +[bold]Repo Type:[/bold] {model.repo_type or '-'} +[bold]Repo ID:[/bold] {model.repo_id or '-'} +[bold]SHA256:[/bold] {sha256_display}{gpu_links_info}""" + + panel = Panel(content, border_style="green", padding=(0, 1)) + console.print(panel) + console.print() + + # If repo changed, suggest verification + if repo_changed and model.repo_id: + console.print() + console.print(f"[bold yellow]{t('model_edit_repo_changed_warning')}[/bold yellow]") + console.print() + console.print(f" {t('model_edit_verify_hint')}") + console.print() + + return + + elif choice == cancel_option: + # Cancel + console.print() + if has_changes: + if Confirm.ask(f"[yellow]{t('model_edit_discard_changes')}[/yellow]", default=False): + print_info(t("model_edit_cancelled")) + console.print() + return + else: + # Go back to menu + continue + else: + print_info(t("model_edit_cancelled")) + console.print() + return + + +@app.command(name="info") +def info_model( + name: str = typer.Argument(..., help="Name of model to display"), +) -> None: + """Display detailed information about a model.""" + from rich.panel import Panel + from pathlib import Path + from kt_kernel.cli.utils.user_model_registry import UserModelRegistry + from kt_kernel.cli.utils.model_scanner import format_size + + registry = UserModelRegistry() + + # Load model + model = registry.get_model(name) + if not model: + print_error(t("model_info_not_found", name=name)) + console.print() + console.print(f" {t('model_info_list_hint')} [cyan]kt model list[/cyan]") + console.print() + raise typer.Exit(1) console.print() - if not matches: - print_warning(t("model_search_no_results", query=query)) + # Check if path exists + path_status = "[green]✓ Exists[/green]" if model.path_exists() else "[red]✗ Missing[/red]" + + # 
Format repo info + if model.repo_id: + repo_abbr = "hf" if model.repo_type == "huggingface" else "ms" + repo_info = f"{repo_abbr}:{model.repo_id}" + else: + repo_info = "-" + + # Format SHA256 status + sha256_display = SHA256_STATUS_MAP_PLAIN.get(model.sha256_status, model.sha256_status) + + # Calculate folder size and list files if exists + moe_info = "" + amx_info = "" + + if model.path_exists(): + path_obj = Path(model.path) + try: + if model.format == "safetensors": + files = list(path_obj.glob("*.safetensors")) + + # Check for AMX weights + is_amx, numa_count = is_amx_weights(str(path_obj)) + if is_amx: + amx_info = f"\n[bold]AMX Format:[/bold] Yes (NUMA: {numa_count})" + else: + # Check for MOE model + try: + from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model + + moe_result = analyze_moe_model(str(path_obj)) + if moe_result and moe_result.get("num_experts", 0) > 0: + moe_info = f""" +[bold]MoE Info:[/bold] + • Total Experts: {moe_result['num_experts']} + • Activated Experts: {moe_result['num_experts_per_tok']} experts/token + • Hidden Layers: {moe_result['num_hidden_layers']} + • Total Model Size: {moe_result['total_size_gb']:.2f} GB""" + except Exception: + pass # Not a MoE model or analysis failed + else: + files = list(path_obj.glob("*.gguf")) + + total_size = sum(f.stat().st_size for f in files if f.exists()) + size_str = format_size(total_size) + file_count = len(files) + size_info = f"{size_str} ({file_count} files)" + + # List first few files + file_list = "\n".join([f" • {f.name}" for f in sorted(files)[:5]]) + if len(files) > 5: + file_list += f"\n ... 
and {len(files) - 5} more files" + except Exception as e: + size_info = f"Error calculating size: {e}" + file_list = "-" + else: + size_info = "-" + file_list = "[red]Path does not exist[/red]" + + # Format created/verified dates + from datetime import datetime + + try: + created_date = datetime.fromisoformat(model.created_at).strftime("%Y-%m-%d %H:%M:%S") + except: + created_date = model.created_at + + if model.last_verified: + try: + verified_date = datetime.fromisoformat(model.last_verified).strftime("%Y-%m-%d %H:%M:%S") + except: + verified_date = model.last_verified + else: + verified_date = "-" + + # Create detailed panel + content = f"""[bold]Name:[/bold] {model.name} +[bold]Path:[/bold] {model.path} +[bold]Path Status:[/bold] {path_status} +[bold]Format:[/bold] {model.format} +[bold]Size:[/bold] {size_info}{amx_info}{moe_info} +[bold]Repo Type:[/bold] {model.repo_type or '-'} +[bold]Repo ID:[/bold] {model.repo_id or '-'} +[bold]SHA256:[/bold] {sha256_display} +[bold]Created:[/bold] {created_date} +[bold]Last Verified:[/bold] {verified_date} + +[bold]Files:[/bold] +{file_list}""" + + panel = Panel(content, title=f"[cyan]Model Information: {model.name}[/cyan]", border_style="cyan", padding=(1, 2)) + console.print(panel) + console.print() + + +@app.command(name="remove") +def remove_model( + name: str = typer.Argument(..., help="Name of model to remove"), + yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation"), +) -> None: + """Remove a model from the registry (does not delete files).""" + from kt_kernel.cli.utils.user_model_registry import UserModelRegistry + + registry = UserModelRegistry() + + # Check if model exists + model = registry.get_model(name) + if not model: + print_error(t("model_remove_not_found", name=name)) + console.print() + console.print(f" {t('model_remove_list_hint')} [cyan]kt model list[/cyan]") + console.print() + raise typer.Exit(1) + + console.print() + console.print(f"[bold yellow]{t('model_remove_warning')}[/bold 
yellow]") + console.print(f" {t('model_remove_note')}") + console.print(f" [dim]Path: {model.path}[/dim]") + console.print() + + # Check if this GPU model is linked by any CPU models + model_uuid = model.id + affected_cpu_models = [] + + # Only check for GPU models (safetensors that are not AMX) + if model.format == "safetensors": + is_amx, _ = is_amx_weights(model.path) + if not is_amx: + # This is a GPU model, check for CPU models that link to it + for m in registry.list_models(): + if m.gpu_model_ids and model_uuid in m.gpu_model_ids: + affected_cpu_models.append(m) + + # If there are affected CPU models, inform the user + if affected_cpu_models: + console.print(f"[yellow]This GPU model is linked by {len(affected_cpu_models)} CPU model(s):[/yellow]") + for cpu_model in affected_cpu_models: + console.print(f" • {cpu_model.name}") + console.print() + console.print(f"[dim]These links will be automatically removed.[/dim]") + console.print() + + # Confirm deletion + if not yes: + if not confirm(t("model_remove_confirm", name=name), default=False): + print_info(t("model_remove_cancelled")) + console.print() + return + + # Clean up references in CPU models before removing + if affected_cpu_models: + for cpu_model in affected_cpu_models: + # Remove this GPU model's UUID from the cpu_model's gpu_model_ids list + new_gpu_ids = [gid for gid in cpu_model.gpu_model_ids if gid != model_uuid] + registry.update_model(cpu_model.name, {"gpu_model_ids": new_gpu_ids if new_gpu_ids else None}) + + # Remove from registry + if registry.remove_model(name): + console.print() + print_success(t("model_removed", name=name)) + console.print() + else: + print_error(t("model_remove_failed", name=name)) + raise typer.Exit(1) + + +@app.command(name="refresh") +def refresh_models() -> None: + """Check all registered models and identify missing ones.""" + from rich.table import Table + from kt_kernel.cli.utils.user_model_registry import UserModelRegistry + + registry = UserModelRegistry() + 
models = registry.list_models() + + if not models: + print_warning(t("model_no_registered_models")) console.print() return - table = Table(title=t("model_search_results_title", query=query), show_header=True) - table.add_column(t("model_column_name"), style="cyan") - table.add_column(t("model_column_hf_repo"), style="dim") - table.add_column(t("model_column_aliases"), style="yellow") + console.print() + print_info(t("model_refresh_checking")) - for model in matches: - aliases = ", ".join(model.aliases[:3]) - if len(model.aliases) > 3: - aliases += f" +{len(model.aliases) - 3} more" - table.add_row(model.name, model.hf_repo, aliases) + # Refresh status + status = registry.refresh_status() + + # Check relationship integrity + broken_relationships = [] # [(cpu_model, gpu_uuid, gpu_name_or_none)] + for model in models: + if model.gpu_model_ids: + for gpu_uuid in model.gpu_model_ids: + gpu_obj = registry.get_model_by_id(gpu_uuid) + if not gpu_obj: + broken_relationships.append((model.name, gpu_uuid, None)) + elif not gpu_obj.path_exists(): + broken_relationships.append((model.name, gpu_uuid, gpu_obj.name)) + + console.print() + + # Show results + has_issues = status["missing"] or broken_relationships + + if not has_issues: + print_success(t("model_refresh_all_valid", count=len(models))) + console.print(f" {t('model_refresh_total', total=len(models))}") + console.print() + return + + # Show broken relationships + if broken_relationships: + print_warning(f"Found {len(broken_relationships)} broken GPU link(s)") + console.print() + + from rich.table import Table + + rel_table = Table(show_header=True, header_style="bold yellow") + rel_table.add_column("CPU Model", style="cyan") + rel_table.add_column("GPU Model", style="dim") + rel_table.add_column("Issue", style="red") + + for cpu_name, gpu_uuid, gpu_name in broken_relationships: + if gpu_name is None: + gpu_display = f"{gpu_uuid[:8]}..." 
+ issue = "Deleted" + else: + gpu_display = gpu_name + issue = "Path Missing" + rel_table.add_row(cpu_name, gpu_display, issue) + + console.print(rel_table) + console.print() + console.print(f"[dim]Use [cyan]kt model edit [/cyan] to fix GPU links[/dim]") + console.print() + + if not status["missing"]: + # Only broken relationships, no missing models + return + + # Show missing models + print_warning(t("model_refresh_missing_found", count=len(status["missing"]))) + console.print() + + table = Table(show_header=True, header_style="bold") + table.add_column(t("model_column_name"), style="cyan") + table.add_column(t("model_column_path"), style="dim") + table.add_column(t("model_column_status"), justify="center") + + for model in models: + if model.name in status["missing"]: + status_text = "[red]✗ Missing[/red]" + else: + status_text = "[green]✓ Valid[/green]" + + table.add_row(model.name, model.path, status_text) console.print(table) console.print() + + # Suggest actions + console.print(f"[bold]{t('model_refresh_suggestions')}:[/bold]") + console.print(f" • {t('model_refresh_remove_hint')} [cyan]kt model remove [/cyan]") + console.print(f" • {t('model_refresh_rescan_hint')} [cyan]kt model scan[/cyan]") + console.print() + + +@app.command(name="verify") +def verify_model( + name: str = typer.Argument(None, help="Name of model to verify (interactive if not provided)"), + verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed SHA256 comparison for each file"), +) -> None: + """Verify model integrity using SHA256 checksums with interactive repair.""" + from pathlib import Path + from rich.prompt import Prompt, Confirm + from rich.table import Table + from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn, MofNCompleteColumn + from kt_kernel.cli.utils.user_model_registry import UserModelRegistry + from kt_kernel.cli.utils.model_verifier import verify_model_integrity_with_progress, check_huggingface_connectivity + + 
registry = UserModelRegistry() + + # Helper function to display model selection table + def show_model_table(): + from kt_kernel.cli.utils.model_scanner import format_size + from pathlib import Path + + # Import MoE analyzer + analyze_moe_model = None + try: + from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model + except ImportError: + pass + + all_models = registry.list_models() + + # Filter: only safetensors models with repo_id + verifiable_models = [m for m in all_models if m.repo_id and m.format == "safetensors"] + + if not verifiable_models: + print_warning(t("model_verify_all_no_repos")) + console.print() + console.print(f" {t('model_verify_all_config_hint')}") + console.print() + return None + + # Analyze MoE models + moe_results = {} + if analyze_moe_model: + for model in verifiable_models: + try: + result = analyze_moe_model(model.path, use_cache=True) + if result and result.get("num_experts", 0) > 0: + moe_results[model.name] = result + except Exception: + pass + + # Filter to only show MoE models + moe_verifiable_models = [m for m in verifiable_models if m.name in moe_results] + + if not moe_verifiable_models: + console.print() + console.print("[yellow]No MoE models with repo_id found for verification.[/yellow]") + console.print() + console.print( + f"[dim]Only MoE models can be verified. 
Use [cyan]kt model list[/cyan] to see all models.[/dim]" + ) + console.print() + return None + + console.print() + console.print("[bold]Select a MoE model to verify:[/bold]\n") + + table = Table(show_header=True, header_style="bold", show_lines=False) + table.add_column("#", justify="right", style="dim", width=4) + table.add_column(t("model_column_name"), style="cyan", no_wrap=True) + table.add_column("Path", style="dim", overflow="fold") + table.add_column("Total", justify="right") + table.add_column("Exps", justify="center", style="yellow") + table.add_column("Act", justify="center", style="green") + table.add_column(t("model_column_repo"), style="dim", overflow="fold") + table.add_column(t("model_column_sha256"), justify="center") + + for i, model in enumerate(moe_verifiable_models, 1): + # Calculate size + if model.path_exists(): + path_obj = Path(model.path) + try: + files = list(path_obj.glob("*.safetensors")) + total_size = sum(f.stat().st_size for f in files if f.exists()) + size_display = format_size(total_size) + except: + size_display = "[dim]-[/dim]" + else: + size_display = "[dim]-[/dim]" + + # Get MoE info + moe_info = moe_results.get(model.name) + experts_display = f"[yellow]{moe_info['num_experts']}[/yellow]" if moe_info else "[dim]-[/dim]" + activated_display = f"[green]{moe_info['num_experts_per_tok']}[/green]" if moe_info else "[dim]-[/dim]" + + # Repo info + repo_abbr = "hf" if model.repo_type == "huggingface" else "ms" + repo_display = f"{repo_abbr}:{model.repo_id}" + + # SHA256 status + status_icon = { + "not_checked": "[dim]○[/dim]", + "checking": "[yellow]◐[/yellow]", + "passed": "[green]✓[/green]", + "failed": "[red]✗[/red]", + "no_repo": "[dim]-[/dim]", + }.get(model.sha256_status, "[dim]?[/dim]") + + table.add_row( + str(i), + model.name, + model.path, + size_display, + experts_display, + activated_display, + repo_display, + status_icon, + ) + + console.print(table) + console.print() + console.print("[dim]SHA256 Status: ○ Not checked | ✓ 
Passed | ✗ Failed[/dim]") + console.print() + + return moe_verifiable_models + + # Main verification loop + # Track files to verify (None = all files, list = specific files for re-verification) + files_to_verify = None + + while True: + selected_model = None + + # If name provided directly, use it once then switch to interactive + if name: + selected_model = registry.get_model(name) + if not selected_model: + print_error(t("model_verify_not_found", name=name)) + console.print() + console.print(f" {t('model_verify_list_hint')} [cyan]kt model list[/cyan]") + console.print() + raise typer.Exit(1) + name = None # Clear so next loop is interactive + else: + # Show interactive selection + verifiable_models = show_model_table() + if not verifiable_models: + return + + choice = Prompt.ask("Enter model number to verify (or 'q' to quit)", default="1") + + if choice.lower() == "q": + return + + try: + idx = int(choice) - 1 + if 0 <= idx < len(verifiable_models): + selected_model = verifiable_models[idx] + # Reset files_to_verify when selecting a new model + files_to_verify = None + else: + print_error(f"Invalid selection: {choice}") + console.print() + continue + except ValueError: + print_error(f"Invalid input: {choice}") + console.print() + continue + + # Check model prerequisites + console.print() + + if not selected_model.repo_id: + print_warning(t("model_verify_no_repo", name=selected_model.name)) + console.print() + console.print(f" {t('model_verify_config_hint', name=selected_model.name)}") + console.print() + continue + + if not selected_model.path_exists(): + print_error(t("model_verify_path_missing", path=selected_model.path)) + console.print() + continue + + # Check HuggingFace connectivity and decide whether to use mirror + use_mirror = False + if selected_model.repo_type == "huggingface": + with console.status("[dim]Checking HuggingFace connectivity...[/dim]"): + is_accessible, message = check_huggingface_connectivity(timeout=5) + + if not is_accessible: + 
print_warning("HuggingFace Connection Failed") + console.print() + console.print(f" {message}") + console.print() + console.print(" [yellow]Auto-switching to HuggingFace mirror:[/yellow] [cyan]hf-mirror.com[/cyan]") + console.print() + use_mirror = True + + # Perform verification with progress bar + if files_to_verify: + print_info(f"Re-verifying {len(files_to_verify)} repaired files: {selected_model.name}") + else: + print_info(f"Verifying: {selected_model.name}") + console.print(f" Repository: [yellow]{selected_model.repo_type}[/yellow]:{selected_model.repo_id}") + console.print(f" Local path: {selected_model.path}") + console.print() + + # Helper function to fetch remote hashes with timeout (using console.status like connectivity check) + def fetch_remote_hashes_with_timeout(repo_type, repo_id, use_mirror, timeout_seconds): + """Fetch remote hashes with timeout, returns (hashes_dict, timed_out).""" + from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError + from kt_kernel.cli.utils.model_verifier import fetch_model_sha256 + + def fetch_hashes(): + platform = "hf" if repo_type == "huggingface" else "ms" + return fetch_model_sha256(repo_id, platform, use_mirror=use_mirror, timeout=timeout_seconds) + + executor = ThreadPoolExecutor(max_workers=1) + try: + future = executor.submit(fetch_hashes) + hashes = future.result(timeout=timeout_seconds) + executor.shutdown(wait=False) + return (hashes, False) + except (FutureTimeoutError, Exception): + executor.shutdown(wait=False) + return (None, True) + + # Step 1: Fetch remote hashes with timeout and fallback + official_hashes = None + + if selected_model.repo_type == "huggingface": + # HF fallback chain: HF → HF-mirror → MS + + # Try 1: HuggingFace (or HF-mirror if already set) + status = console.status( + "[dim]Fetching remote hashes from HuggingFace{}...[/dim]".format(" mirror" if use_mirror else "") + ) + status.start() + official_hashes, timed_out = fetch_remote_hashes_with_timeout( + 
repo_type="huggingface", repo_id=selected_model.repo_id, use_mirror=use_mirror, timeout_seconds=10 + ) + status.stop() + + # Try 2: If timed out and not already using mirror, try HF-mirror + if timed_out and not use_mirror: + print_warning("HuggingFace Fetch Timeout (10s)") + console.print() + console.print(" [yellow]Auto-switching to HuggingFace mirror:[/yellow] [cyan]hf-mirror.com[/cyan]") + console.print() + + status = console.status("[dim]Fetching remote hashes from HuggingFace mirror...[/dim]") + status.start() + official_hashes, timed_out = fetch_remote_hashes_with_timeout( + repo_type="huggingface", + repo_id=selected_model.repo_id, + use_mirror=True, # Use mirror + timeout_seconds=10, + ) + status.stop() + + # Try 3: If still timed out, try ModelScope with same repo_id + if timed_out: + print_warning("HuggingFace Mirror Timeout (10s)") + console.print() + console.print(" [yellow]Fallback to ModelScope mirror with same repo_id...[/yellow]") + console.print() + + status = console.status("[dim]Fetching remote hashes from ModelScope...[/dim]") + status.start() + official_hashes, timed_out = fetch_remote_hashes_with_timeout( + repo_type="modelscope", + repo_id=selected_model.repo_id, # Use same repo_id + use_mirror=False, + timeout_seconds=10, + ) + status.stop() + + if official_hashes: + # Success with ModelScope + console.print(" [green]✓ Successfully fetched from ModelScope[/green]") + console.print() + elif timed_out: + # All failed + print_error("All sources timed out (HuggingFace and ModelScope)") + console.print() + console.print(" Please check your network connection or try again later") + console.print() + continue + + elif selected_model.repo_type == "modelscope": + # ModelScope: no fallback, just timeout + status = console.status("[dim]Fetching remote hashes from ModelScope...[/dim]") + status.start() + official_hashes, timed_out = fetch_remote_hashes_with_timeout( + repo_type="modelscope", repo_id=selected_model.repo_id, use_mirror=False, 
timeout_seconds=10 + ) + status.stop() + + if timed_out: + print_error("ModelScope Fetch Timeout (10s)") + console.print() + console.print(" Please check your network connection or try again later") + console.print() + continue + + # Check if we successfully fetched remote hashes + if not official_hashes: + # Already printed error message above, skip to next model + continue + + # Success - print confirmation + console.print(f" [green]✓ Fetched {len(official_hashes)} file hashes from remote[/green]") + console.print() + + # Step 2 & 3: Calculate local SHA256 and compare (with Progress bar) + from kt_kernel.cli.utils.model_verifier import calculate_local_sha256 + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + console=console, + ) as progress: + # Step 2: Calculate local SHA256 hashes (no timeout) + local_dir_path = Path(selected_model.path) + + # Determine which files to hash + if files_to_verify: + # Only hash files that need re-verification + clean_filenames = { + Path(f.replace(" (missing)", "").replace(" (hash mismatch)", "").strip()).name + for f in files_to_verify + } + # Collect files matching *.safetensors, *.json, *.py + files_to_hash = [] + for pattern in ["*.safetensors", "*.json", "*.py"]: + files_to_hash.extend( + [f for f in local_dir_path.glob(pattern) if f.is_file() and f.name in clean_filenames] + ) + else: + # Collect all important files: *.safetensors, *.json, *.py + files_to_hash = [] + for pattern in ["*.safetensors", "*.json", "*.py"]: + files_to_hash.extend([f for f in local_dir_path.glob(pattern) if f.is_file()]) + + total_files = len(files_to_hash) + + # Create progress task for local hashing + hash_task_id = progress.add_task("[yellow]Calculating local SHA256...", total=total_files) + completed_count = [0] + + def local_hash_callback(msg: str): + if "Using" in msg and "workers" in msg: + # Show parallel worker info + 
console.print(f" [dim]{msg}[/dim]") + elif "[" in msg and "/" in msg and "]" in msg: + # Progress update + completed_count[0] += 1 + if "✓" in msg: + filename = msg.split("✓")[1].strip().split("(")[0].strip() + progress.update(hash_task_id, advance=1, description=f"[yellow]Hashing: {filename[:40]}...") + + local_hashes = calculate_local_sha256( + local_dir_path, + "*.safetensors", + progress_callback=local_hash_callback, + files_list=files_to_hash if files_to_verify else None, + ) + + progress.remove_task(hash_task_id) + console.print(f" [green]✓ Calculated {len(local_hashes)} local file hashes[/green]") + + # Step 3: Compare hashes + # If re-verifying specific files, only compare those files + if files_to_verify: + # Build set of clean filenames to verify + clean_verify_filenames = { + Path(f.replace(" (missing)", "").replace(" (hash mismatch)", "").strip()).name + for f in files_to_verify + } + # Filter official_hashes to only include files we're re-verifying + hashes_to_compare = { + filename: hash_value + for filename, hash_value in official_hashes.items() + if Path(filename).name in clean_verify_filenames + } + else: + # First-time verification: compare all files + hashes_to_compare = official_hashes + + compare_task_id = progress.add_task("[blue]Comparing hashes...", total=len(hashes_to_compare)) + + files_failed = [] + files_missing = [] + files_passed = 0 + + for filename, official_hash in hashes_to_compare.items(): + file_basename = Path(filename).name + + # Find matching local file + local_hash = None + for local_file, local_hash_value in local_hashes.items(): + if Path(local_file).name == file_basename: + local_hash = local_hash_value + break + + if local_hash is None: + files_missing.append(filename) + if verbose: + console.print(f" [red]✗ {file_basename} (missing)[/red]") + elif local_hash.lower() != official_hash.lower(): + files_failed.append(f"{filename} (hash mismatch)") + if verbose: + console.print(f" [red]✗ {file_basename} (hash 
mismatch)[/red]") + else: + files_passed += 1 + if verbose: + console.print(f" [green]✓ {file_basename}[/green]") + + progress.update(compare_task_id, advance=1) + + progress.remove_task(compare_task_id) + + # Build result + total_checked = len(hashes_to_compare) # Use actual compared count + if files_failed or files_missing: + all_failed = files_failed + [f"{f} (missing)" for f in files_missing] + result = { + "status": "failed", + "files_checked": total_checked, + "files_passed": files_passed, + "files_failed": all_failed, + } + else: + result = { + "status": "passed", + "files_checked": total_checked, + "files_passed": files_passed, + "files_failed": [], + } + + # Update registry status and display results + if result["status"] == "passed": + registry.update_model(selected_model.name, {"sha256_status": "passed"}) + console.print() + print_success(t("model_verify_passed")) + console.print() + console.print(f" ✓ Files checked: [bold green]{result['files_checked']}[/bold green]") + console.print(f" ✓ All files passed SHA256 verification") + console.print() + elif result["status"] == "failed": + registry.update_model(selected_model.name, {"sha256_status": "failed"}) + console.print() + print_error(f"Verification failed! 
{len(result['files_failed'])} file(s) have issues") + console.print() + console.print(f" Total files: {result['files_checked']}") + console.print(f" ✓ Passed: [green]{result['files_passed']}[/green]") + console.print(f" ✗ Failed: [red]{len(result['files_failed'])}[/red]") + console.print() + + # Show failed files (only if not already shown in verbose mode) + if not verbose: + console.print(" [bold red]Failed files:[/bold red]") + for failed_file in result["files_failed"]: + console.print(f" ✗ {failed_file}") + console.print() + + # Ask if user wants to repair + if Confirm.ask("[yellow]Do you want to repair (re-download) the failed files?[/yellow]", default=True): + console.print() + print_info("Repairing failed files...") + + # Extract clean filenames by removing status suffixes + files_to_download = [ + f.replace(" (missing)", "").replace(" (hash mismatch)", "").strip() for f in result["files_failed"] + ] + + # Download each failed file + success_count = 0 + + # Set mirror for downloads if needed + import os + + original_hf_endpoint = os.environ.get("HF_ENDPOINT") + if use_mirror and selected_model.repo_type == "huggingface" and not original_hf_endpoint: + os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" + console.print(f" [dim]Using HuggingFace mirror for downloads[/dim]") + + try: + for file_to_repair in files_to_download: + console.print(f" Repairing: [cyan]{file_to_repair}[/cyan]") + + # Step 1: Delete the corrupted/missing file if it exists + local_file_path = Path(selected_model.path) / file_to_repair + if local_file_path.exists(): + try: + local_file_path.unlink() + console.print(f" [dim]✓ Deleted corrupted file[/dim]") + except Exception as e: + console.print(f" [yellow]⚠ Could not delete file: {e}[/yellow]") + + # Step 2: Download the fresh file + if selected_model.repo_type == "huggingface": + # Use hf_hub_download for HuggingFace (inherits HF_ENDPOINT env var) + try: + from huggingface_hub import hf_hub_download + + hf_hub_download( + 
repo_id=selected_model.repo_id, + filename=file_to_repair, + local_dir=selected_model.path, + local_dir_use_symlinks=False, + ) + console.print(f" [green]✓ Downloaded successfully[/green]") + success_count += 1 + except ImportError: + print_error("huggingface_hub not installed. Install: pip install huggingface_hub") + break + except Exception as e: + console.print(f" [red]✗ Download failed: {e}[/red]") + else: + # Use modelscope download for ModelScope + try: + from modelscope.hub.snapshot_download import snapshot_download + + # Download directly to local_dir + snapshot_download( + model_id=selected_model.repo_id, + local_dir=selected_model.path, + allow_file_pattern=file_to_repair, + ) + console.print(f" [green]✓ Downloaded successfully[/green]") + success_count += 1 + except ImportError: + print_error("modelscope not installed. Install: pip install modelscope") + break + except Exception as e: + console.print(f" [red]✗ Download failed: {e}[/red]") + finally: + # Restore original HF_ENDPOINT + if use_mirror and selected_model.repo_type == "huggingface" and not original_hf_endpoint: + os.environ.pop("HF_ENDPOINT", None) + elif original_hf_endpoint: + os.environ["HF_ENDPOINT"] = original_hf_endpoint + + console.print() + if success_count > 0: + print_success(f"Repaired {success_count}/{len(files_to_download)} files") + console.print() + + # Ask if user wants to re-verify + if Confirm.ask("Re-verify the model now?", default=True): + # Re-verify by continuing the loop with the same model + # Only verify the files that were repaired + name = selected_model.name + files_to_verify = files_to_download + continue + + +@app.command(name="verify-all") +def verify_all_models() -> None: + """Verify all models with repo configuration (not yet implemented).""" + from kt_kernel.cli.utils.user_model_registry import UserModelRegistry + + registry = UserModelRegistry() + models = registry.list_models() + + # Filter models with repo configuration + models_with_repo = [m for m in 
models if m.repo_id] + + if not models_with_repo: + print_warning(t("model_verify_all_no_repos")) + console.print() + console.print(f" {t('model_verify_all_config_hint')} [cyan]kt model edit [/cyan]") + console.print() + return + + console.print() + print_warning(t("model_verify_not_implemented")) + console.print() + console.print(f" {t('model_verify_all_found', count=len(models_with_repo))}") + console.print() + + for model in models_with_repo: + console.print(f" • {model.name} ({model.repo_type}:{model.repo_id})") + + console.print() + console.print(f" [dim]{t('model_verify_future_note')}[/dim]") + console.print() + console.print(f" {t('model_verify_all_manual_hint')} [cyan]kt model verify [/cyan]") + console.print() + + +@app.command(name="auto-repo") +def auto_detect_repo( + apply: bool = typer.Option( + False, "--apply", "-a", help="Automatically apply detected repo information without confirmation" + ), + dry_run: bool = typer.Option( + False, "--dry-run", "-d", help="Show what would be detected without making any changes" + ), +) -> None: + """ + Auto-detect repository information from model README.md files. + + Scans all models without repo_id (safetensors/gguf only) and attempts to + extract repository information from README.md metadata (license_link field). 
+ + Examples: + kt model auto-repo # Scan and ask for confirmation + kt model auto-repo --apply # Scan and apply automatically + kt model auto-repo --dry-run # Scan only, no changes + """ + from kt_kernel.cli.utils.user_model_registry import UserModelRegistry + from kt_kernel.cli.utils.repo_detector import scan_models_for_repo, format_detection_report, apply_detection_results + from rich.table import Table + + console.print() + print_info("Scanning models for repository information...") + console.print() + + # Get all models + registry = UserModelRegistry() + models = registry.list_models() + + if not models: + print_warning("No models found in registry") + console.print() + return + + # Scan for repo information + print_step("Analyzing README.md files...") + results = scan_models_for_repo(models) + + # Show results + console.print() + + if not results["detected"] and not results["not_detected"]: + print_info("All models already have repository information configured") + console.print() + return + + # Create results table + if results["detected"]: + console.print("[bold green]✓ Detected Repository Information[/bold green]") + console.print() + + table = Table(show_header=True, header_style="bold cyan") + table.add_column("Model Name", style="yellow") + table.add_column("Repository", style="cyan") + table.add_column("Type", style="magenta") + + for model, repo_id, repo_type in results["detected"]: + table.add_row(model.name, repo_id, repo_type) + + console.print(table) + console.print() + + if results["not_detected"]: + console.print( + f"[bold yellow]✗ No Repository Information Found ({len(results['not_detected'])} models)[/bold yellow]" + ) + console.print() + + for model in results["not_detected"][:5]: # Show first 5 + console.print(f" • {model.name}") + + if len(results["not_detected"]) > 5: + console.print(f" ... 
and {len(results['not_detected']) - 5} more") + + console.print() + + if results["skipped"]: + console.print( + f"[dim]⊘ Skipped {len(results['skipped'])} models (already configured or not safetensors/gguf)[/dim]" + ) + console.print() + + # Summary + console.print("[bold]Summary:[/bold]") + console.print(f" • [green]{len(results['detected'])}[/green] detected") + console.print(f" • [yellow]{len(results['not_detected'])}[/yellow] not detected") + console.print(f" • [dim]{len(results['skipped'])}[/dim] skipped") + console.print() + + # Exit if dry run or no detections + if dry_run: + print_info("Dry run mode - no changes made") + console.print() + return + + if not results["detected"]: + console.print() + return + + # Ask for confirmation (unless --apply flag) + if not apply: + console.print() + if not confirm(f"Apply repository information to {len(results['detected'])} model(s)?", default=False): + print_warning("Cancelled - no changes made") + console.print() + return + + # Apply changes + console.print() + print_step("Applying changes...") + + updated_count = apply_detection_results(results, registry) + + console.print() + if updated_count > 0: + print_success(f"✓ Updated {updated_count} model(s) with repository information") + console.print() + console.print(" You can now:") + console.print(" • Run [cyan]kt model verify [/cyan] to verify model integrity") + console.print(" • Check status with [cyan]kt model list[/cyan]") + console.print() + else: + print_error("Failed to update models") + console.print() diff --git a/kt-kernel/python/cli/commands/quant.py b/kt-kernel/python/cli/commands/quant.py index c6cf2c3..d97961e 100644 --- a/kt-kernel/python/cli/commands/quant.py +++ b/kt-kernel/python/cli/commands/quant.py @@ -35,12 +35,12 @@ class QuantMethod(str, Enum): def quant( - model: str = typer.Argument( - ..., + model: Optional[str] = typer.Argument( + None, help="Model name or path to quantize", ), - method: QuantMethod = typer.Option( - QuantMethod.INT4, + 
method: Optional[QuantMethod] = typer.Option( + None, "--method", "-m", help="Quantization method", @@ -51,8 +51,8 @@ def quant( "-o", help="Output path for quantized weights", ), - input_type: str = typer.Option( - "fp8", + input_type: Optional[str] = typer.Option( + None, "--input-type", "-i", help="Input weight type (fp8, fp16, bf16)", @@ -72,6 +72,11 @@ def quant( "--no-merge", help="Don't merge safetensor files", ), + gpu: bool = typer.Option( + False, + "--gpu", + help="Use GPU for conversion (faster)", + ), yes: bool = typer.Option( False, "--yes", @@ -79,54 +84,231 @@ def quant( help="Skip confirmation prompts", ), ) -> None: - """Quantize model weights for CPU inference.""" - settings = get_settings() - console.print() + """Quantize model weights for CPU inference. - # Resolve input path - input_path = _resolve_input_path(model, settings) - if input_path is None: - print_error(t("quant_input_not_found", path=model)) + If no model is specified, interactive mode will be activated. 
+ """ + settings = get_settings() + + # Check if we should use interactive mode + # Interactive mode triggers when: no model, or missing critical parameters + needs_interactive = model is None or method is None or cpu_threads is None or numa_nodes is None + is_interactive = False + + if needs_interactive and sys.stdin.isatty(): + # Use interactive configuration (includes verification in Step 1.5) + from kt_kernel.cli.utils.quant_interactive import interactive_quant_config + + console.print() + console.print(f"[bold cyan]═══ {t('quant_interactive_title')} ═══[/bold cyan]") + console.print() + console.print(f"[yellow]{t('quant_new_model_notice')}[/yellow]") + console.print() + + config = interactive_quant_config() + if config is None: + # User cancelled + raise typer.Exit(0) + + # Extract configuration + model_obj = config["model"] + model = model_obj.id + input_path = Path(model_obj.path) + method = QuantMethod(config["method"]) + input_type = config["input_type"] + cpu_threads = config["cpu_threads"] + numa_nodes = config["numa_nodes"] + output = config["output_path"] + gpu = config["use_gpu"] + is_interactive = True + + console.print() + print_success(t("quant_config_complete")) + console.print() + else: + # Non-interactive mode - require model parameter + if model is None: + print_error("Model argument is required in non-interactive mode") + console.print() + console.print("Usage: kt quant ") + console.print(" Or: kt quant (for interactive mode)") + raise typer.Exit(1) + + # Set defaults for optional parameters + method = method or QuantMethod.INT4 + input_type = input_type or "fp8" + + console.print() + + # Resolve input path + input_path = _resolve_input_path(model, settings) + if input_path is None: + print_error(t("quant_input_not_found", path=model)) + raise typer.Exit(1) + + # Pre-quantization verification (only in non-interactive mode) + # Interactive mode already did verification in interactive_quant_config() + from kt_kernel.cli.utils.user_model_registry 
import UserModelRegistry + from kt_kernel.cli.utils.model_verifier import pre_operation_verification + + user_registry = UserModelRegistry() + user_model_obj = user_registry.find_by_path(str(input_path)) + + if user_model_obj and user_model_obj.format == "safetensors": + pre_operation_verification(user_model_obj, user_registry, operation_name="quantizing") + + # Get user model info for both modes (needed later for registering quantized model) + from kt_kernel.cli.utils.user_model_registry import UserModelRegistry + + user_registry = UserModelRegistry() + user_model_obj = user_registry.find_by_path(str(input_path)) + + # Validate that it's a MoE model (not AMX or GGUF) + from kt_kernel.cli.commands.model import is_amx_weights + + # Check if it's AMX (already quantized) + is_amx, _ = is_amx_weights(str(input_path)) + if is_amx: + print_error("Cannot quantize AMX models (already quantized)") + console.print() + console.print(f" The model at {input_path} is already in AMX format.") raise typer.Exit(1) - print_info(t("quant_input_path", path=str(input_path))) + # Check if it's a MoE model + from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model - # Resolve output path - if output is None: - output = input_path.parent / f"{input_path.name}-{method.value.upper()}" - - print_info(t("quant_output_path", path=str(output))) - print_info(t("quant_method", method=method.value.upper())) - - # Detect CPU configuration - cpu = detect_cpu_info() - final_cpu_threads = cpu_threads or cpu.cores - final_numa_nodes = numa_nodes or cpu.numa_nodes - - print_info(f"CPU threads: {final_cpu_threads}") - print_info(f"NUMA nodes: {final_numa_nodes}") - - # Check if output exists - if output.exists(): - print_warning(f"Output path already exists: {output}") + moe_result = None # Store for later use when registering quantized model + try: + moe_result = analyze_moe_model(str(input_path), use_cache=True) + if not moe_result or not moe_result.get("is_moe"): + print_error("Only MoE 
models can be quantized to AMX format") + console.print() + console.print(f" The model at {input_path} is not a MoE model.") + console.print(" AMX quantization is designed for MoE models (e.g., DeepSeek-V3).") + raise typer.Exit(1) + except Exception as e: + print_warning(f"Could not detect MoE information: {e}") + console.print() if not yes: - if not confirm("Overwrite?", default=False): + if not confirm("Continue quantization anyway?", default=False): + raise typer.Exit(1) + + # Detect CPU configuration and resolve output path (only needed in non-interactive mode) + if not is_interactive: + print_info(t("quant_input_path", path=str(input_path))) + + # Detect CPU configuration (needed for output path) + cpu = detect_cpu_info() + final_cpu_threads = cpu_threads or cpu.cores + final_numa_nodes = numa_nodes or cpu.numa_nodes + + # Resolve output path + if output is None: + # Priority: paths.weights > paths.models[0] > model's parent directory + weights_dir = settings.weights_dir + + if weights_dir and weights_dir.exists(): + # Use configured weights directory (highest priority) + output = weights_dir / f"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}" + else: + # Use first model storage path + model_paths = settings.get_model_paths() + if model_paths and model_paths[0].exists(): + output = model_paths[0] / f"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}" + else: + # Fallback to model's parent directory + output = input_path.parent / f"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}" + + print_info(t("quant_output_path", path=str(output))) + print_info(t("quant_method", method=method.value.upper())) + print_info(t("quant_cpu_threads", threads=final_cpu_threads)) + print_info(t("quant_numa_nodes", nodes=final_numa_nodes)) + + # Calculate space requirements + console.print() + console.print(f"[bold cyan]{t('quant_disk_analysis')}[/bold cyan]") + console.print() + + # Calculate source model size + try: + 
total_bytes = sum(f.stat().st_size for f in input_path.glob("*.safetensors") if f.is_file()) + source_size_gb = total_bytes / (1024**3) + except Exception: + source_size_gb = 0.0 + + # Estimate quantized size + input_bits = {"fp8": 8, "fp16": 16, "bf16": 16} + quant_bits = {"int4": 4, "int8": 8} + input_bit = input_bits.get(input_type, 16) + quant_bit = quant_bits.get(method.value, 4) + ratio = quant_bit / input_bit + estimated_size_gb = source_size_gb * ratio + + # Check available space + import shutil + + try: + check_path = output.parent if not output.exists() else output + while not check_path.exists() and check_path != check_path.parent: + check_path = check_path.parent + stat = shutil.disk_usage(check_path) + available_gb = stat.free / (1024**3) + except Exception: + available_gb = 0.0 + + is_sufficient = available_gb >= (estimated_size_gb * 1.2) + + console.print(f" {t('quant_source_size'):<26} {source_size_gb:.2f} GB") + console.print(f" {t('quant_estimated_size'):<26} {estimated_size_gb:.2f} GB") + console.print(f" {t('quant_available_space'):<26} {available_gb:.2f} GB") + console.print() + + if not is_sufficient: + required_with_buffer = estimated_size_gb * 1.2 + print_warning(t("quant_insufficient_space")) + console.print() + console.print(f" {t('quant_required_space'):<26} {required_with_buffer:.2f} GB") + console.print(f" {t('quant_available_space'):<26} {available_gb:.2f} GB") + console.print(f" {t('quant_shortage'):<26} {required_with_buffer - available_gb:.2f} GB") + console.print() + console.print(f" {t('quant_may_fail')}") + console.print() + + if not yes: + if not confirm(t("quant_continue_anyway"), default=False): + raise typer.Abort() + console.print() + + # Check if output exists and generate unique name + if output.exists(): + print_warning(t("quant_output_exists", path=str(output))) + console.print() + + # Generate unique name by adding suffix + original_name = output.name + parent_dir = output.parent + counter = 2 + + while output.exists(): 
+ new_name = f"{original_name}-{counter}" + output = parent_dir / new_name + counter += 1 + + print_success(t("quant_using_unique", path=str(output))) + console.print() + + # Confirm (only show if not using --yes flag) + if not yes: + console.print() + print_warning(t("quant_time_warning")) + console.print() + + if not confirm(t("prompt_continue")): raise typer.Abort() - - # Confirm - if not yes: - console.print() - console.print("[bold]Quantization Settings:[/bold]") - console.print(f" Input: {input_path}") - console.print(f" Output: {output}") - console.print(f" Method: {method.value.upper()}") - console.print(f" Input type: {input_type}") - console.print() - print_warning("Quantization may take 30-60 minutes depending on model size.") - console.print() - - if not confirm(t("prompt_continue")): - raise typer.Abort() + else: + # Interactive mode: cpu_threads and numa_nodes already set + final_cpu_threads = cpu_threads + final_numa_nodes = numa_nodes # Find conversion script kt_kernel_path = _find_kt_kernel_path() @@ -141,37 +323,145 @@ def quant( # Build command cmd = [ - sys.executable, str(script_path), - "--input-path", str(input_path), - "--input-type", input_type, - "--output", str(output), - "--quant-method", method.value, - "--cpuinfer-threads", str(final_cpu_threads), - "--threadpool-count", str(final_numa_nodes), + sys.executable, + str(script_path), + "--input-path", + str(input_path), + "--input-type", + input_type, + "--output", + str(output), + "--quant-method", + method.value, + "--cpuinfer-threads", + str(final_cpu_threads), + "--threadpool-count", + str(final_numa_nodes), ] if no_merge: cmd.append("--no-merge-safetensor") + if gpu: + cmd.append("--gpu") + # Run quantization console.print() print_step(t("quant_starting")) console.print() console.print(f"[dim]$ {' '.join(cmd)}[/dim]") console.print() + console.print("[dim]" + "=" * 80 + "[/dim]") + console.print() try: - process = subprocess.run(cmd) + # Run with real-time stdout/stderr output + 
import os + import time + + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" # Disable Python output buffering + + # Record start time + start_time = time.time() + + process = subprocess.run( + cmd, + stdout=None, # Inherit parent's stdout (real-time output) + stderr=None, # Inherit parent's stderr (real-time output) + env=env, + ) + + # Calculate elapsed time + elapsed_time = time.time() - start_time + hours = int(elapsed_time // 3600) + minutes = int((elapsed_time % 3600) // 60) + seconds = int(elapsed_time % 60) + + console.print() + console.print("[dim]" + "=" * 80 + "[/dim]") + console.print() if process.returncode == 0: - console.print() print_success(t("quant_complete")) console.print() + + # Display elapsed time + if hours > 0: + time_str = f"{hours}h {minutes}m {seconds}s" + elif minutes > 0: + time_str = f"{minutes}m {seconds}s" + else: + time_str = f"{seconds}s" + console.print(f" [cyan]{t('quant_time_elapsed')} {time_str}[/cyan]") + console.print() console.print(f" Quantized weights saved to: {output}") console.print() - console.print(" Use with:") - console.print(f" kt run {model} --weights-path {output}") - console.print() + + # Auto-register the quantized model + try: + from kt_kernel.cli.utils.user_model_registry import UserModel + + # Generate model name from output path + base_name = output.name + suggested_name = user_registry.suggest_name(base_name) + + # Determine MoE information and source model name + if user_model_obj: + is_moe_val = user_model_obj.is_moe + num_experts = user_model_obj.moe_num_experts + num_active = user_model_obj.moe_num_experts_per_tok + repo_type_val = user_model_obj.repo_type + repo_id_val = user_model_obj.repo_id + source_model_name = user_model_obj.name # Store source model name + elif moe_result: + is_moe_val = moe_result.get("is_moe", True) + num_experts = moe_result.get("num_experts") + num_active = moe_result.get("num_experts_per_tok") + repo_type_val = None + repo_id_val = None + source_model_name = 
input_path.name # Use folder name as fallback + else: + is_moe_val = None + num_experts = None + num_active = None + repo_type_val = None + repo_id_val = None + source_model_name = input_path.name # Use folder name as fallback + + # Create new model entry (AMX format uses "safetensors" format, detected by is_amx_weights()) + new_model = UserModel( + name=suggested_name, + path=str(output), + format="safetensors", # AMX files are safetensors format + repo_type=repo_type_val, + repo_id=repo_id_val, + sha256_status="not_checked", # AMX weights don't need verification + # Inherit MoE information from source model + is_moe=is_moe_val, + moe_num_experts=num_experts, + moe_num_experts_per_tok=num_active, + # AMX quantization metadata + amx_source_model=source_model_name, + amx_quant_method=method.value, # "int4" or "int8" + amx_numa_nodes=final_numa_nodes, + ) + + user_registry.add_model(new_model) + console.print() + print_success(t("quant_registered", name=suggested_name)) + console.print() + console.print(f" {t('quant_view_with')} [cyan]kt model list[/cyan]") + console.print(f" {t('quant_use_with')} [cyan]kt run {suggested_name}[/cyan]") + console.print() + except Exception as e: + # Non-fatal error - quantization succeeded but registration failed + console.print() + print_warning(t("quant_register_failed", error=str(e))) + console.print() + console.print(f" {t('quant_use_with')}") + console.print(f" kt run {model} --weights-path {output}") + console.print() else: print_error(f"Quantization failed with exit code {process.returncode}") raise typer.Exit(process.returncode) @@ -221,6 +511,7 @@ def _find_kt_kernel_path() -> Optional[Path]: """Find the kt-kernel installation path.""" try: import kt_kernel + return Path(kt_kernel.__file__).parent.parent except ImportError: pass diff --git a/kt-kernel/python/cli/commands/run.py b/kt-kernel/python/cli/commands/run.py index b2bb475..08ad216 100644 --- a/kt-kernel/python/cli/commands/run.py +++ 
b/kt-kernel/python/cli/commands/run.py @@ -28,7 +28,7 @@ from kt_kernel.cli.utils.console import ( prompt_choice, ) from kt_kernel.cli.utils.environment import detect_cpu_info, detect_gpus, detect_ram_gb -from kt_kernel.cli.utils.model_registry import MODEL_COMPUTE_FUNCTIONS, ModelInfo, get_registry +from kt_kernel.cli.utils.user_model_registry import UserModelRegistry @click.command( @@ -120,8 +120,6 @@ def run( # Handle disable/enable shared experts fusion flags if enable_shared_experts_fusion: disable_shared_experts_fusion = False - elif disable_shared_experts_fusion is None: - disable_shared_experts_fusion = None # Convert Path objects from click model_path_obj = Path(model_path) if model_path else None @@ -214,266 +212,250 @@ def _run_impl( raise typer.Exit(1) settings = get_settings() - registry = get_registry() + user_registry = UserModelRegistry() - console.print() + # Check if we should use interactive mode + # Interactive mode triggers when: + # 1. No model specified, OR + # 2. Model specified but missing critical parameters (gpu_experts, tensor_parallel_size, etc.) 
+ use_interactive = False - # If no model specified, show interactive selection if model is None: - model = _interactive_model_selection(registry, settings) - if model is None: + use_interactive = True + elif ( + gpu_experts is None + or tensor_parallel_size is None + or cpu_threads is None + or numa_nodes is None + or max_total_tokens is None + ): + # Model specified but some parameters missing - use interactive + use_interactive = True + + if use_interactive and sys.stdin.isatty(): + # Use new interactive configuration flow + from kt_kernel.cli.utils.run_interactive import interactive_run_config + + console.print() + console.print("[bold cyan]═══ Interactive Run Configuration ═══[/bold cyan]") + console.print() + + config = interactive_run_config() + if config is None: + # User cancelled raise typer.Exit(0) - # Step 1: Detect hardware - print_step(t("run_detecting_hardware")) - gpus = detect_gpus() - cpu = detect_cpu_info() - ram = detect_ram_gb() + # Extract configuration from new format + user_model_obj = config["model"] + model = user_model_obj.id + resolved_model_path = Path(config["model_path"]) + resolved_weights_path = Path(config["weights_path"]) - if gpus: - gpu_info = f"{gpus[0].name} ({gpus[0].vram_gb}GB VRAM)" - if len(gpus) > 1: - gpu_info += f" + {len(gpus) - 1} more" - print_info(t("run_gpu_info", name=gpus[0].name, vram=gpus[0].vram_gb)) + # Extract parameters + gpu_experts = config["gpu_experts"] + cpu_threads = config["cpu_threads"] + numa_nodes = config["numa_nodes"] + tensor_parallel_size = config["tp_size"] + + # Get kt-method and other method-specific settings + kt_method = config["kt_method"] + + # KV cache settings (may be None for non-raw methods) + max_total_tokens = config.get("kv_cache", 32768) + chunked_prefill_size = config.get("chunk_prefill", 32768) + kt_gpu_prefill_threshold = config.get("gpu_prefill_threshold", 500) + + # Memory settings + mem_fraction_static = config["mem_fraction_static"] + + # Parser settings (optional) + 
tool_call_parser = config.get("tool_call_parser") + reasoning_parser = config.get("reasoning_parser") + + # Server settings + host = config.get("host", "0.0.0.0") + port = config.get("port", 30000) + + # Set CUDA_VISIBLE_DEVICES for selected GPUs + selected_gpus = config["selected_gpus"] + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu_id) for gpu_id in selected_gpus) + + # Detect hardware for parameter resolution (needed for resolve() function later) + gpus = detect_gpus() + cpu = detect_cpu_info() + + console.print() + print_info(f"[green]✓[/green] Configuration complete") + console.print() else: - print_warning(t("doctor_gpu_not_found")) - gpu_info = "None" + # Non-interactive mode - use traditional flow + console.print() - print_info(t("run_cpu_info", name=cpu.name, cores=cpu.cores, numa=cpu.numa_nodes)) - print_info(t("run_ram_info", total=int(ram))) + # Initialize variables that may have been set by interactive mode + # These will be None in non-interactive mode and will use defaults via resolve() - # Step 2: Resolve model - console.print() - print_step(t("run_checking_model")) + # If no model specified, show old interactive selection + if model is None: + model = _interactive_model_selection(user_registry, settings) + if model is None: + raise typer.Exit(0) - model_info = None - resolved_model_path = model_path + # Detect hardware (needed for defaults) + gpus = detect_gpus() + cpu = detect_cpu_info() + ram = detect_ram_gb() - # Check if model is a path - if Path(model).exists(): - resolved_model_path = Path(model) - print_info(t("run_model_path", path=str(resolved_model_path))) - - # Try to infer model type from path to use default configurations - # Check directory name against known models - dir_name = resolved_model_path.name.lower() - for registered_model in registry.list_all(): - # Check if directory name matches model name or aliases - if dir_name == registered_model.name.lower(): - model_info = registered_model - print_info(f"Detected model 
type: {registered_model.name}") - break - for alias in registered_model.aliases: - if dir_name == alias.lower() or alias.lower() in dir_name: - model_info = registered_model - print_info(f"Detected model type: {registered_model.name}") - break - if model_info: - break - - # Also check HuggingFace repo format (org--model) - if not model_info: - for registered_model in registry.list_all(): - repo_slug = registered_model.hf_repo.replace("/", "--").lower() - if repo_slug in dir_name or dir_name in repo_slug: - model_info = registered_model - print_info(f"Detected model type: {registered_model.name}") - break - - if not model_info: - print_warning("Could not detect model type from path. Using default parameters.") - console.print(" [dim]Tip: Use model name (e.g., 'kt run m2') to apply optimized configurations[/dim]") - else: - # Search in registry - matches = registry.search(model) - - if not matches: - print_error(t("run_model_not_found", name=model)) - console.print() - console.print("Available models:") - for m in registry.list_all()[:5]: - console.print(f" - {m.name} ({', '.join(m.aliases[:2])})") - raise typer.Exit(1) - - if len(matches) == 1: - model_info = matches[0] + if gpus: + gpu_info = f"{gpus[0].name} ({gpus[0].vram_gb}GB VRAM)" + if len(gpus) > 1: + gpu_info += f" + {len(gpus) - 1} more" + print_info(t("run_gpu_info", name=gpus[0].name, vram=gpus[0].vram_gb)) else: - # Multiple matches - prompt user - console.print() - print_info(t("run_multiple_matches")) - choices = [f"{m.name} ({m.hf_repo})" for m in matches] - selected = prompt_choice(t("run_select_model"), choices) - idx = choices.index(selected) - model_info = matches[idx] + print_warning(t("doctor_gpu_not_found")) + gpu_info = "None" - # Find model path - if model_path is None: - resolved_model_path = _find_model_path(model_info, settings) - if resolved_model_path is None: - print_error(t("run_model_not_found", name=model_info.name)) + print_info(t("run_cpu_info", name=cpu.name, cores=cpu.cores, 
numa=cpu.numa_nodes)) + print_info(t("run_ram_info", total=int(ram))) + + # Step 2: Resolve model + console.print() + print_step(t("run_checking_model")) + + user_model_obj = None + resolved_model_path = model_path + + # Check if model is a path + if Path(model).exists(): + resolved_model_path = Path(model) + print_info(t("run_model_path", path=str(resolved_model_path))) + + # Try to find in user registry by path + user_model_obj = user_registry.find_by_path(str(resolved_model_path)) + if user_model_obj: + print_info(f"Using registered model: {user_model_obj.name}") + else: + print_warning("Using unregistered model path. Consider adding it with 'kt model add'") + else: + # Search in user registry by name + user_model_obj = user_registry.get_model(model) + + if not user_model_obj: + print_error(t("run_model_not_found", name=model)) console.print() - console.print( - f" Download with: kt download {model_info.aliases[0] if model_info.aliases else model_info.name}" - ) + + # Show available models + all_models = user_registry.list_models() + if all_models: + console.print("Available registered models:") + for m in all_models[:5]: + console.print(f" - {m.name}") + if len(all_models) > 5: + console.print(f" ... 
and {len(all_models) - 5} more") + else: + console.print("No models registered yet.") + + console.print() + console.print(f"Add your model with: [cyan]kt model add /path/to/model[/cyan]") + console.print(f"Or scan for models: [cyan]kt model scan[/cyan]") raise typer.Exit(1) - print_info(t("run_model_path", path=str(resolved_model_path))) + # Use model path from registry + resolved_model_path = Path(user_model_obj.path) - # Step 3: Check quantized weights (only if explicitly requested) - resolved_weights_path = None + # Verify path exists + if not resolved_model_path.exists(): + print_error(f"Model path does not exist: {resolved_model_path}") + console.print() + console.print(f"Run 'kt model refresh' to check all models") + raise typer.Exit(1) - # Only use quantized weights if explicitly specified by user - if weights_path is not None: - # User explicitly specified weights path - resolved_weights_path = weights_path - if not resolved_weights_path.exists(): - print_error(t("run_weights_not_found")) - console.print(f" Path: {resolved_weights_path}") + print_info(t("run_model_path", path=str(resolved_model_path))) + + # Step 2.5: Pre-run verification (optional integrity check) + if user_model_obj and user_model_obj.format == "safetensors": + from kt_kernel.cli.utils.model_verifier import pre_operation_verification + + pre_operation_verification(user_model_obj, user_registry, operation_name="running") + + # Step 3: Check quantized weights (only if explicitly requested) + resolved_weights_path = None + + # Only use quantized weights if explicitly specified by user + if weights_path is not None: + # User explicitly specified weights path + resolved_weights_path = weights_path + if not resolved_weights_path.exists(): + print_error(t("run_weights_not_found")) + console.print(f" Path: {resolved_weights_path}") + raise typer.Exit(1) + print_info(f"Using quantized weights: {resolved_weights_path}") + elif quantize: + # User requested quantization + console.print() + 
print_step(t("run_quantizing")) + # TODO: Implement quantization + print_warning("Quantization not yet implemented. Please run 'kt quant' manually.") raise typer.Exit(1) - print_info(f"Using quantized weights: {resolved_weights_path}") - elif quantize: - # User requested quantization - console.print() - print_step(t("run_quantizing")) - # TODO: Implement quantization - print_warning("Quantization not yet implemented. Please run 'kt quant' manually.") - raise typer.Exit(1) - else: - # Default: use original precision model without quantization - console.print() - print_info("Using original precision model (no quantization)") + else: + # Default: use original precision model without quantization + console.print() + print_info("Using original precision model (no quantization)") # Step 4: Build command - # Resolve all parameters (CLI > model defaults > config > auto-detect) - final_host = host or settings.get("server.host", "0.0.0.0") - final_port = port or settings.get("server.port", 30000) + # Helper to resolve parameter with fallback chain: CLI > config > default + def resolve(cli_val, config_key, default): + if cli_val is not None: + return cli_val + config_val = settings.get(config_key) + return config_val if config_val is not None else default - # Get defaults from model info if available - model_defaults = model_info.default_params if model_info else {} + # Server configuration + final_host = resolve(host, "server.host", "0.0.0.0") + final_port = resolve(port, "server.port", 30000) - # Determine tensor parallel size first (needed for GPU expert calculation) - # Priority: CLI > model defaults > config > auto-detect (with model constraints) - - # Check if explicitly specified by user or configuration - explicitly_specified = ( - tensor_parallel_size # CLI argument (highest priority) - or model_defaults.get("tensor-parallel-size") # Model defaults - or settings.get("inference.tensor_parallel_size") # Config file + # Tensor parallel size: CLI > config > auto-detect 
from GPUs + final_tensor_parallel_size = resolve( + tensor_parallel_size, "inference.tensor_parallel_size", len(gpus) if gpus else 1 ) - if explicitly_specified: - # Use explicitly specified value - requested_tensor_parallel_size = explicitly_specified - else: - # Auto-detect from GPUs, considering model's max constraint - detected_gpu_count = len(gpus) if gpus else 1 - if model_info and model_info.max_tensor_parallel_size is not None: - # Automatically limit to model's maximum to use as many GPUs as possible - requested_tensor_parallel_size = min(detected_gpu_count, model_info.max_tensor_parallel_size) - else: - requested_tensor_parallel_size = detected_gpu_count - - # Apply model's max_tensor_parallel_size constraint if explicitly specified value exceeds it - final_tensor_parallel_size = requested_tensor_parallel_size - if model_info and model_info.max_tensor_parallel_size is not None: - if requested_tensor_parallel_size > model_info.max_tensor_parallel_size: - console.print() - print_warning( - f"Model {model_info.name} only supports up to {model_info.max_tensor_parallel_size}-way " - f"tensor parallelism, but {requested_tensor_parallel_size} was requested. " - f"Reducing to {model_info.max_tensor_parallel_size}." 
- ) - final_tensor_parallel_size = model_info.max_tensor_parallel_size - # CPU/GPU configuration with smart defaults - # kt-cpuinfer: default to 80% of total CPU threads (cores * NUMA nodes) - total_threads = cpu.cores * cpu.numa_nodes - final_cpu_threads = ( - cpu_threads - or model_defaults.get("kt-cpuinfer") - or settings.get("inference.cpu_threads") - or int(total_threads * 0.8) - ) - - # kt-threadpool-count: default to NUMA node count - final_numa_nodes = ( - numa_nodes - or model_defaults.get("kt-threadpool-count") - or settings.get("inference.numa_nodes") - or cpu.numa_nodes - ) - - # kt-num-gpu-experts: use model-specific computation if available and not explicitly set - if gpu_experts is not None: - # User explicitly set it - final_gpu_experts = gpu_experts - elif model_info and model_info.name in MODEL_COMPUTE_FUNCTIONS and gpus: - # Use model-specific computation function (only if GPUs detected) - vram_per_gpu = gpus[0].vram_gb - compute_func = MODEL_COMPUTE_FUNCTIONS[model_info.name] - final_gpu_experts = compute_func(final_tensor_parallel_size, vram_per_gpu) - console.print() - print_info( - f"Auto-computed kt-num-gpu-experts: {final_gpu_experts} (TP={final_tensor_parallel_size}, VRAM={vram_per_gpu}GB per GPU)" - ) - else: - # Fall back to defaults - final_gpu_experts = model_defaults.get("kt-num-gpu-experts") or settings.get("inference.gpu_experts", 1) + total_threads = cpu.threads # Use logical threads instead of physical cores + final_cpu_threads = resolve(cpu_threads, "inference.cpu_threads", int(total_threads * 0.8)) + final_numa_nodes = resolve(numa_nodes, "inference.numa_nodes", cpu.numa_nodes) + final_gpu_experts = resolve(gpu_experts, "inference.gpu_experts", 1) # KT-kernel options - final_kt_method = kt_method or model_defaults.get("kt-method") or settings.get("inference.kt_method", "AMXINT4") - final_kt_gpu_prefill_threshold = ( - kt_gpu_prefill_threshold - or model_defaults.get("kt-gpu-prefill-token-threshold") - or 
settings.get("inference.kt_gpu_prefill_token_threshold", 4096) - ) + final_kt_method = resolve(kt_method, "inference.kt_method", "AMXINT4") + final_kt_gpu_prefill_threshold = resolve(kt_gpu_prefill_threshold, "inference.kt_gpu_prefill_token_threshold", 4096) # SGLang options - final_attention_backend = ( - attention_backend - or model_defaults.get("attention-backend") - or settings.get("inference.attention_backend", "triton") - ) - final_max_total_tokens = ( - max_total_tokens or model_defaults.get("max-total-tokens") or settings.get("inference.max_total_tokens", 40000) - ) - final_max_running_requests = ( - max_running_requests - or model_defaults.get("max-running-requests") - or settings.get("inference.max_running_requests", 32) - ) - final_chunked_prefill_size = ( - chunked_prefill_size - or model_defaults.get("chunked-prefill-size") - or settings.get("inference.chunked_prefill_size", 4096) - ) - final_mem_fraction_static = ( - mem_fraction_static - or model_defaults.get("mem-fraction-static") - or settings.get("inference.mem_fraction_static", 0.98) - ) - final_watchdog_timeout = ( - watchdog_timeout or model_defaults.get("watchdog-timeout") or settings.get("inference.watchdog_timeout", 3000) - ) - final_served_model_name = ( - served_model_name or model_defaults.get("served-model-name") or settings.get("inference.served_model_name", "") - ) + final_attention_backend = resolve(attention_backend, "inference.attention_backend", "flashinfer") + final_max_total_tokens = resolve(max_total_tokens, "inference.max_total_tokens", 40000) + final_max_running_requests = resolve(max_running_requests, "inference.max_running_requests", 32) + final_chunked_prefill_size = resolve(chunked_prefill_size, "inference.chunked_prefill_size", 4096) + final_mem_fraction_static = resolve(mem_fraction_static, "inference.mem_fraction_static", 0.98) + final_watchdog_timeout = resolve(watchdog_timeout, "inference.watchdog_timeout", 3000) + final_served_model_name = resolve(served_model_name, 
"inference.served_model_name", "") # Performance flags - if disable_shared_experts_fusion is not None: - final_disable_shared_experts_fusion = disable_shared_experts_fusion - elif "disable-shared-experts-fusion" in model_defaults: - final_disable_shared_experts_fusion = model_defaults["disable-shared-experts-fusion"] - else: - final_disable_shared_experts_fusion = settings.get("inference.disable_shared_experts_fusion", False) + final_disable_shared_experts_fusion = resolve( + disable_shared_experts_fusion, "inference.disable_shared_experts_fusion", True + ) - # Pass all model default params to handle any extra parameters - extra_params = model_defaults if model_info else {} + # Pass extra CLI parameters + extra_params = {} + + # Parser parameters (from interactive mode or None in non-interactive mode) + final_tool_call_parser = None + final_reasoning_parser = None + if "tool_call_parser" in locals() and tool_call_parser: + final_tool_call_parser = tool_call_parser + if "reasoning_parser" in locals() and reasoning_parser: + final_reasoning_parser = reasoning_parser cmd = _build_sglang_command( model_path=resolved_model_path, weights_path=resolved_weights_path, - model_info=model_info, host=final_host, port=final_port, gpu_experts=final_gpu_experts, @@ -490,6 +472,8 @@ def _run_impl( watchdog_timeout=final_watchdog_timeout, served_model_name=final_served_model_name, disable_shared_experts_fusion=final_disable_shared_experts_fusion, + tool_call_parser=final_tool_call_parser, + reasoning_parser=final_reasoning_parser, settings=settings, extra_model_params=extra_params, extra_cli_args=extra_cli_args, @@ -508,11 +492,9 @@ def _run_impl( console.print() print_step("Configuration") - # Model info - if model_info: - console.print(f" Model: [bold]{model_info.name}[/bold]") - else: - console.print(f" Model: [bold]{resolved_model_path.name}[/bold]") + # Display model name + model_display_name = user_model_obj.name if user_model_obj else resolved_model_path.name + 
console.print(f" Model: [bold]{model_display_name}[/bold]") console.print(f" Path: [dim]{resolved_model_path}[/dim]") @@ -572,88 +554,13 @@ def _run_impl( raise typer.Exit(1) -def _find_model_path(model_info: ModelInfo, settings, max_depth: int = 3) -> Optional[Path]: - """Find the model path on disk by searching all configured model paths. - - Args: - model_info: Model information to search for - settings: Settings instance - max_depth: Maximum depth to search within each model path (default: 3) - - Returns: - Path to the model directory, or None if not found - """ - model_paths = settings.get_model_paths() - - # Generate possible names to search for - possible_names = [ - model_info.name, - model_info.name.lower(), - model_info.name.replace(" ", "-"), - model_info.hf_repo.split("/")[-1], - model_info.hf_repo.replace("/", "--"), - ] - - # Add alias-based names - for alias in model_info.aliases: - possible_names.append(alias) - possible_names.append(alias.lower()) - - # Search in all configured model directories - for models_dir in model_paths: - if not models_dir.exists(): - continue - - # Search recursively up to max_depth - for depth in range(max_depth): - for name in possible_names: - if depth == 0: - # Direct children: models_dir / name - search_paths = [models_dir / name] - else: - # Nested: use rglob to find directories matching the name - search_paths = list(models_dir.rglob(name)) - - for path in search_paths: - if path.exists() and (path / "config.json").exists(): - return path - - return None - - -def _find_weights_path(model_info: ModelInfo, settings) -> Optional[Path]: - """Find the quantized weights path on disk by searching all configured paths.""" - model_paths = settings.get_model_paths() - weights_dir = settings.weights_dir - - # Check common patterns - base_names = [ - model_info.name, - model_info.name.lower(), - model_info.hf_repo.split("/")[-1], - ] - - suffixes = ["-INT4", "-int4", "_INT4", "_int4", "-quant", "-quantized"] - - # Prepare 
search directories - search_dirs = [weights_dir] if weights_dir else [] - search_dirs.extend(model_paths) - - for base in base_names: - for suffix in suffixes: - for dir_path in search_dirs: - if dir_path: - path = dir_path / f"{base}{suffix}" - if path.exists(): - return path - - return None +# Dead code removed: _find_model_path() and _find_weights_path() +# These functions were part of the old builtin model system def _build_sglang_command( model_path: Path, weights_path: Optional[Path], - model_info: Optional[ModelInfo], host: str, port: int, gpu_experts: int, @@ -670,6 +577,8 @@ def _build_sglang_command( watchdog_timeout: int, served_model_name: str, disable_shared_experts_fusion: bool, + tool_call_parser: Optional[str], + reasoning_parser: Optional[str], settings, extra_model_params: Optional[dict] = None, # New parameter for additional params extra_cli_args: Optional[list[str]] = None, # Extra args from CLI to pass to sglang @@ -700,9 +609,6 @@ def _build_sglang_command( elif cpu_threads > 0 or gpu_experts > 1: # CPU offloading configured - use kt-kernel use_kt_kernel = True - elif model_info and model_info.type == "moe": - # MoE model - likely needs kt-kernel for expert offloading - use_kt_kernel = True if use_kt_kernel: # Add kt-weight-path: use quantized weights if available, otherwise use model path @@ -723,6 +629,7 @@ def _build_sglang_command( kt_method, "--kt-gpu-prefill-token-threshold", str(kt_gpu_prefill_threshold), + "--kt-enable-dynamic-expert-update", # Enable dynamic expert updates ] ) @@ -757,6 +664,16 @@ def _build_sglang_command( if disable_shared_experts_fusion: cmd.append("--disable-shared-experts-fusion") + # Add FP8 backend if using FP8 method + if "FP8" in kt_method.upper(): + cmd.extend(["--fp8-gemm-backend", "triton"]) + + # Add parsers if specified + if tool_call_parser: + cmd.extend(["--tool-call-parser", tool_call_parser]) + if reasoning_parser: + cmd.extend(["--reasoning-parser", reasoning_parser]) + # Add any extra parameters 
from model defaults that weren't explicitly handled if extra_model_params: # List of parameters already handled above @@ -801,30 +718,31 @@ def _build_sglang_command( return cmd -def _interactive_model_selection(registry, settings) -> Optional[str]: +def _interactive_model_selection(user_registry, settings) -> Optional[str]: """Show interactive model selection interface. Returns: Selected model name or None if cancelled. """ from rich.panel import Panel - from rich.table import Table from rich.prompt import Prompt - from kt_kernel.cli.i18n import get_lang + # Get all user models + all_models = user_registry.list_models() - lang = get_lang() - - # Find local models first - local_models = registry.find_local_models() - - # Get all registered models - all_models = registry.list_all() + if not all_models: + console.print() + print_warning("No models registered.") + console.print() + console.print(f" Add models with: [cyan]kt model scan[/cyan]") + console.print(f" Or manually: [cyan]kt model add /path/to/model[/cyan]") + console.print() + return None console.print() console.print( Panel.fit( - t("run_select_model_title"), + "Select a model to run", border_style="cyan", ) ) @@ -834,54 +752,30 @@ def _interactive_model_selection(registry, settings) -> Optional[str]: choices = [] choice_map = {} # index -> model name - # Section 1: Local models (downloaded) - if local_models: - console.print(f"[bold green]{t('run_local_models')}[/bold green]") - console.print() - - for i, (model_info, path) in enumerate(local_models, 1): - desc = model_info.description_zh if lang == "zh" else model_info.description - short_desc = desc[:50] + "..." 
if len(desc) > 50 else desc - console.print(f" [cyan][{i}][/cyan] [bold]{model_info.name}[/bold]") - console.print(f" [dim]{short_desc}[/dim]") - console.print(f" [dim]{path}[/dim]") - choices.append(str(i)) - choice_map[str(i)] = model_info.name - - console.print() - - # Section 2: All registered models (for reference) - start_idx = len(local_models) + 1 - console.print(f"[bold yellow]{t('run_registered_models')}[/bold yellow]") + # Show all user models + console.print(f"[bold green]Available Models:[/bold green]") console.print() - # Filter out already shown local models - local_model_names = {m.name for m, _ in local_models} - - for i, model_info in enumerate(all_models, start_idx): - if model_info.name in local_model_names: - continue - - desc = model_info.description_zh if lang == "zh" else model_info.description - short_desc = desc[:50] + "..." if len(desc) > 50 else desc - console.print(f" [cyan][{i}][/cyan] [bold]{model_info.name}[/bold]") - console.print(f" [dim]{short_desc}[/dim]") - console.print(f" [dim]{model_info.hf_repo}[/dim]") + for i, model in enumerate(all_models, 1): + # Check if path exists + path_status = "✓" if model.path_exists() else "✗ Missing" + console.print(f" [cyan][{i}][/cyan] [bold]{model.name}[/bold] [{path_status}]") + console.print(f" [dim]{model.format} - {model.path}[/dim]") choices.append(str(i)) - choice_map[str(i)] = model_info.name + choice_map[str(i)] = model.name console.print() # Add cancel option cancel_idx = str(len(choices) + 1) - console.print(f" [cyan][{cancel_idx}][/cyan] [dim]{t('cancel')}[/dim]") + console.print(f" [cyan][{cancel_idx}][/cyan] [dim]Cancel[/dim]") choices.append(cancel_idx) console.print() # Prompt for selection try: selection = Prompt.ask( - t("run_select_model_prompt"), + "Select model", choices=choices, default="1" if choices else cancel_idx, ) diff --git a/kt-kernel/python/cli/completions/kt-completion.bash b/kt-kernel/python/cli/completions/kt-completion.bash index 8f1d3be..4989878 100644 --- 
a/kt-kernel/python/cli/completions/kt-completion.bash +++ b/kt-kernel/python/cli/completions/kt-completion.bash @@ -9,7 +9,7 @@ _kt_completion() { prev="${COMP_WORDS[COMP_CWORD-1]}" # Main commands - local commands="version run chat quant bench microbench doctor model config sft" + local commands="version run chat quant edit bench microbench doctor model config sft" # Global options local global_opts="--help --version" @@ -36,6 +36,10 @@ _kt_completion() { local quant_opts="--method --output --help" COMPREPLY=( $(compgen -W "${quant_opts}" -- ${cur}) ) ;; + edit) + local edit_opts="--help" + COMPREPLY=( $(compgen -W "${edit_opts}" -- ${cur}) ) + ;; bench|microbench) local bench_opts="--model --config --help" COMPREPLY=( $(compgen -W "${bench_opts}" -- ${cur}) ) diff --git a/kt-kernel/python/cli/i18n.py b/kt-kernel/python/cli/i18n.py index af90cba..9dbe1c9 100644 --- a/kt-kernel/python/cli/i18n.py +++ b/kt-kernel/python/cli/i18n.py @@ -190,6 +190,70 @@ MESSAGES: dict[str, dict[str, str]] = { "quant_progress": "Quantizing...", "quant_complete": "Quantization complete!", "quant_input_not_found": "Input model not found at {path}", + "quant_cpu_threads": "CPU threads: {threads}", + "quant_numa_nodes": "NUMA nodes: {nodes}", + "quant_time_warning": "Quantization may take 30-60 minutes depending on model size.", + "quant_disk_analysis": "Disk Space Analysis:", + "quant_source_size": "Source model size:", + "quant_estimated_size": "Estimated output size:", + "quant_available_space": "Available space:", + "quant_insufficient_space": "WARNING: Insufficient disk space!", + "quant_required_space": "Required space (with 20% buffer):", + "quant_shortage": "Shortage:", + "quant_may_fail": "Quantization may fail or produce incomplete files.", + "quant_continue_anyway": "Continue anyway?", + "quant_settings": "Quantization Settings:", + "quant_registered": "Quantized model registered: {name}", + "quant_view_with": "View with:", + "quant_use_with": "Use with:", + 
"quant_register_failed": "Failed to auto-register model: {error}", + "quant_output_exists": "Output path already exists: {path}", + "quant_using_unique": "Using unique name: {path}", + # Interactive quant + "quant_interactive_title": "Interactive Quantization Configuration", + "quant_new_model_notice": "⚠ Note: Some newer models cannot be quantized yet (conversion script not adapted). Recommended to use the original precision for inference (no weight conversion needed).", + "quant_no_moe_models": "No MoE models found for quantization.", + "quant_only_moe": "Only MoE models (e.g., DeepSeek-V3) can be quantized to AMX format.", + "quant_add_models": "Add models with: {command}", + "quant_moe_available": "MoE Models Available for Quantization:", + "quant_select_model": "Select model to quantize", + "quant_invalid_choice": "Invalid choice", + "quant_step2_method": "Step 2: Quantization Method", + "quant_method_label": "Quantization Method:", + "quant_int4_desc": "INT4", + "quant_int8_desc": "INT8", + "quant_select_method": "Select quantization method", + "quant_input_type_label": "Input Weight Type:", + "quant_fp8_desc": "FP8 (for 8-bit float weights)", + "quant_fp16_desc": "FP16 (for 16-bit float weights)", + "quant_bf16_desc": "BF16 (for Brain Float 16 weights)", + "quant_select_input_type": "Select input type", + "quant_step3_cpu": "Step 3: CPU Configuration", + "quant_cpu_threads_prompt": "CPU Threads (1 to {max})", + "quant_numa_nodes_prompt": "NUMA Nodes (1 to {max})", + "quant_use_gpu_label": "Use GPU for conversion?", + "quant_gpu_speedup": "GPU can significantly speed up the quantization process", + "quant_enable_gpu": "Enable GPU acceleration?", + "quant_step4_output": "Step 4: Output Path", + "quant_default_path": "Default:", + "quant_use_default": "Use default output path?", + "quant_custom_path": "Enter custom output path", + "quant_output_exists_warn": "⚠ Output path already exists: {path}", + "quant_using_unique_name": "→ Using unique name: {path}", + 
"quant_config_summary": "Configuration Summary", + "quant_summary_model": "Model:", + "quant_summary_method": "Method:", + "quant_summary_input_type": "Input Type:", + "quant_summary_cpu_threads": "CPU Threads:", + "quant_summary_numa": "NUMA Nodes:", + "quant_summary_gpu": "Use GPU:", + "quant_summary_output": "Output Path:", + "quant_start_question": "Start quantization?", + "quant_cancelled": "Cancelled", + "quant_config_complete": "Configuration complete", + "quant_time_elapsed": "Time elapsed:", + "yes": "Yes", + "no": "No", # SFT command "sft_mode_train": "Training mode", "sft_mode_chat": "Chat mode", @@ -247,6 +311,113 @@ MESSAGES: dict[str, dict[str, str]] = { "chat_proxy_detected": "Proxy detected in environment", "chat_proxy_confirm": "Use proxy for connection?", "chat_proxy_disabled": "Proxy disabled for this session", + "chat_openai_required": "OpenAI Python SDK is required for chat functionality.", + "chat_install_hint": "Install it with:", + "chat_title": "KTransformers Chat", + "chat_server": "Server", + "chat_temperature": "Temperature", + "chat_max_tokens": "Max tokens", + "chat_help_hint": "Type '/help' for commands, '/quit' to exit", + "chat_connecting": "Connecting to server...", + "chat_no_models": "No models available on server", + "chat_model_not_found": "Model '{model}' not found. Available models: {available}", + "chat_connected": "Connected to model: {model}", + "chat_connect_failed": "Failed to connect to server: {error}", + "chat_server_not_running": "Make sure the model server is running:", + "chat_user_prompt": "You", + "chat_assistant_prompt": "Assistant", + "chat_generation_error": "Error generating response: {error}", + "chat_interrupted": "Chat interrupted. 
Goodbye!", + "chat_history_saved": "History saved to: {path}", + "chat_goodbye": "Goodbye!", + "chat_help_title": "Available Commands:", + "chat_help_content": "/help, /h - Show this help message\n/quit, /exit, /q - Exit chat\n/clear, /c - Clear conversation history\n/history, /hist - Show conversation history\n/info, /i - Show current settings\n/retry, /r - Regenerate last response", + "chat_history_cleared": "Conversation history cleared", + "chat_no_history": "No conversation history", + "chat_history_title": "History ({count} messages)", + "chat_info_title": "Current Settings:", + "chat_info_content": "Temperature: {temperature}\nMax tokens: {max_tokens}\nMessages: {messages}", + "chat_retrying": "Retrying last response...", + "chat_no_retry": "No previous response to retry", + "chat_unknown_command": "Unknown command: {command}", + "chat_unknown_hint": "Type /help for available commands", + # Run Interactive + "run_int_no_moe_models": "No MoE GPU models found.", + "run_int_add_models": "Add models with: kt model scan", + "run_int_list_all": "List all models: kt model list --all", + "run_int_step1_title": "Step 1: Select Model (GPU MoE Models)", + "run_int_select_model": "Select model", + "run_int_step2_title": "Step 2: Select Inference Method", + "run_int_method_raw": "RAW Precision (FP8/FP8_PERCHANNEL/BF16/RAWINT4)", + "run_int_method_amx": "AMX Quantization (INT4/INT8)", + "run_int_method_gguf": "GGUF (Llamafile)", + "run_int_method_saved": "Use Saved Configuration", + "run_int_select_method": "Select inference method", + "run_int_raw_precision": "RAW Precision:", + "run_int_select_precision": "Select precision", + "run_int_amx_method": "AMX Method:", + "run_int_select_amx": "Select AMX method", + "run_int_step3_title": "Step 3: NUMA and CPU Configuration", + "run_int_numa_nodes": "NUMA Nodes (1-{max})", + "run_int_cpu_threads": "CPU Threads per NUMA (1-{max})", + "run_int_amx_warning": "⚠ Warning: AMX INT4/INT8 requires compatible CPU. 
Check with: kt doctor", + "run_int_step4_title": "Step 4: GPU Experts Configuration", + "run_int_gpu_experts": "GPU Experts per Layer (0-{max})", + "run_int_gpu_experts_info": "Total experts: {total}, Activated per token: {active}", + "run_int_step5_title": "Step 5: KV Cache Configuration", + "run_int_kv_cache_size": "KV Cache Size (tokens)", + "run_int_chunk_prefill": "Enable Chunk Prefill?", + "run_int_chunk_size": "Chunk Prefill Size (tokens)", + "run_int_gpu_prefill_threshold": "GPU Prefill Threshold (tokens)", + "run_int_step6_title": "Step 6: GPU Selection and Tensor Parallelism", + "run_int_available_gpus": "Available GPUs:", + "run_int_gpu_id": "GPU {id}", + "run_int_vram_info": "{name} ({total:.1f}GB total, {free:.1f}GB free)", + "run_int_select_gpus": "Select GPU IDs (comma-separated)", + "run_int_invalid_gpu_range": "All GPU IDs must be between 0 and {max}", + "run_int_tp_size": "TP Size (must be power of 2: 1,2,4,8...)", + "run_int_tp_mismatch": "TP size must match number of selected GPUs ({count})", + "run_int_tp_not_power_of_2": "TP size must be a power of 2", + "run_int_mem_fraction": "Static Memory Fraction (0.0-1.0)", + "run_int_using_saved_mem": "Using saved memory fraction: {fraction}", + "run_int_step7_title": "Step 7: Parser Configuration (Optional)", + "run_int_tool_call_parser": "Tool Call Parser (press Enter to skip)", + "run_int_reasoning_parser": "Reasoning Parser (press Enter to skip)", + "run_int_step8_title": "Step 8: Host and Port Configuration", + "run_int_host": "Host", + "run_int_port": "Port", + "run_int_port_occupied": "⚠ Port {port} is already in use", + "run_int_port_suggestion": "Suggested available port: {port}", + "run_int_use_suggested": "Use suggested port?", + "run_int_saved_configs": "Saved Configurations:", + "run_int_config_name": "Configuration {num}", + "run_int_kt_method": "KT Method:", + "run_int_numa_nodes_label": "NUMA Nodes:", + "run_int_cpu_threads_label": "CPU Threads:", + "run_int_gpu_experts_label": "GPU 
Experts:", + "run_int_tp_size_label": "TP Size:", + "run_int_mem_fraction_label": "Memory Fraction:", + "run_int_server_label": "Server:", + "run_int_kv_cache_label": "KV Cache:", + "run_int_chunk_prefill_label": "Chunk Prefill:", + "run_int_gpu_prefill_label": "GPU Prefill Thr:", + "run_int_tool_parser_label": "Tool Call Parser:", + "run_int_reasoning_parser_label": "Reasoning Parser:", + "run_int_command_label": "Command:", + "run_int_select_config": "Select configuration", + "run_int_gpu_select_required": "Please select {tp} GPUs (TP size from saved config)", + "run_int_port_check_title": "Port Configuration", + "run_int_port_checking": "Checking port {port} availability...", + "run_int_port_available": "Port {port} is available", + "run_int_saved_config_title": "Saved Configuration", + "run_int_save_config_title": "Save Configuration", + "run_int_save_config_prompt": "Save this configuration for future use?", + "run_int_config_name_prompt": "Configuration name", + "run_int_config_name_default": "Config {timestamp}", + "run_int_config_saved": "Configuration saved: {name}", + "run_int_config_summary": "Configuration Complete", + "run_int_model_label": "Model:", + "run_int_selected_gpus_label": "Selected GPUs:", # Model command "model_supported_title": "KTransformers Supported Models", "model_column_model": "Model", @@ -282,6 +453,180 @@ MESSAGES: dict[str, dict[str, str]] = { "model_column_name": "Name", "model_column_hf_repo": "HuggingFace Repo", "model_column_aliases": "Aliases", + # Model management - new user registry system + "model_no_registered_models": "No models registered yet.", + "model_scan_hint": "Scan for models: kt model scan", + "model_add_hint": "Add a model: kt model add /path/to/model", + "model_registered_models_title": "Registered Models", + "model_column_format": "Format", + "model_column_repo": "Repository", + "model_column_sha256": "SHA256", + "model_non_moe_hidden_hint": "Detected {count} non-MoE models, use kt model list --all to show 
all", + "model_usage_title": "Common Operations:", + "model_usage_info": "View details:", + "model_usage_edit": "Edit model:", + "model_usage_verify": "Verify integrity:", + "model_usage_quant": "Quantize model:", + "model_usage_run": "Run model:", + "model_usage_scan": "Scan for models:", + "model_usage_add": "Add model:", + "model_usage_verbose": "View with file details:", + "model_no_storage_paths": "No storage paths configured.", + "model_add_path_hint": "Add a storage path with: kt config set model.storage_paths /path/to/models", + "model_scanning_paths": "Scanning configured storage paths...", + "model_scanning_progress": "Scanning: {path}", + "model_scan_warnings_title": "Warnings", + "model_scan_no_models_found": "No models found in configured paths.", + "model_scan_check_paths_hint": "Check your storage paths: kt config get model.storage_paths", + "model_scan_min_size_hint": "Folders must be ≥{size}GB to be detected as models.", + "model_scan_found_title": "Found {count} new model(s)", + "model_column_path": "Path", + "model_column_size": "Size", + "model_scan_auto_adding": "Auto-adding models...", + "model_added": "Added: {name}", + "model_add_failed": "Failed to add {name}: {error}", + "model_scan_complete": "Scan complete! Added {count} model(s).", + "model_scan_interactive_prompt": "Commands: edit <id> | del <id> | done", + "model_scan_cmd_edit": "Set custom name for model", + "model_scan_cmd_delete": "Skip this model", + "model_scan_cmd_done": "Finish and add models", + "model_scan_marked_skip": "Skipped model #{id}", + "model_scan_invalid_id": "Invalid model ID: {id}", + "model_scan_invalid_command": "Invalid command.
Use: edit <id> | del <id> | done", + "model_scan_edit_model": "Edit model {id}", + "model_scan_edit_note": "You can change the model name before adding it to registry", + "model_scan_adding_models": "Adding {count} model(s)...", + "model_scan_next_steps": "Next Steps", + "model_scan_view_hint": "View registered models: kt model list", + "model_scan_edit_hint": "Edit model details: kt model edit <name>", + "model_scan_no_models_added": "No models were added.", + "model_add_path_not_exist": "Error: Path does not exist: {path}", + "model_add_not_directory": "Error: Path is not a directory: {path}", + "model_add_already_registered": "This path is already registered as: {name}", + "model_add_view_hint": "View with: kt model info {name}", + "model_add_scanning": "Scanning model files...", + "model_add_scan_failed": "Failed to scan model: {error}", + "model_add_no_model_files": "No model files found in {path}", + "model_add_supported_formats": "Supported: *.safetensors, *.gguf (folder ≥10GB)", + "model_add_detected": "Detected: {format} format, {size}, {count} file(s)", + "model_add_name_conflict": "Name '{name}' already exists.", + "model_add_prompt_name": "Enter a name for this model", + "model_add_name_exists": "Name already exists.
Please choose another name:", + "model_add_configure_repo": "Configure repository information for SHA256 verification?", + "model_add_repo_type_prompt": "Select repository type:", + "model_add_choice": "Choice", + "model_add_repo_id_prompt": "Enter repository ID (e.g., deepseek-ai/DeepSeek-V3)", + "model_add_success": "Successfully added model: {name}", + "model_add_verify_hint": "Verify integrity: kt model verify {name}", + "model_add_edit_later_hint": "Edit details later: kt model edit {name}", + "model_add_failed_generic": "Failed to add model: {error}", + "model_edit_not_found": "Model '{name}' not found.", + "model_edit_list_hint": "List models: kt model list", + "model_edit_current_config": "Current Configuration", + "model_edit_what_to_edit": "What would you like to edit?", + "model_edit_option_name": "Edit name", + "model_edit_option_repo": "Configure repository info", + "model_edit_option_delete": "Delete this model", + "model_edit_option_cancel": "Cancel / Exit", + "model_edit_choice_prompt": "Select option", + "model_edit_new_name": "Enter new name", + "model_edit_name_conflict": "Name '{name}' already exists. Please choose another:", + "model_edit_name_updated": "Name updated: {old} → {new}", + "model_edit_repo_type_prompt": "Repository type (or enter to remove repo info):", + "model_edit_repo_remove": "Remove repository info", + "model_edit_repo_id_prompt": "Enter repository ID", + "model_edit_repo_removed": "Repository info removed", + "model_edit_repo_updated": "Repository configured: {repo_type} → {repo_id}", + "model_edit_delete_warning": "Delete model '{name}' from registry?", + "model_edit_delete_note": "Note: This only removes the registry entry. 
Model files in {path} will NOT be deleted.", + "model_edit_delete_confirm": "Confirm deletion?", + "model_edit_deleted": "Model '{name}' deleted from registry", + "model_edit_delete_cancelled": "Deletion cancelled", + "model_edit_cancelled": "Edit cancelled", + # Model edit - Interactive selection + "model_edit_select_title": "Select Model to Edit", + "model_edit_select_model": "Select model", + "model_edit_invalid_choice": "Invalid choice", + "model_edit_no_models": "No models found in registry.", + "model_edit_add_hint_scan": "Add models with:", + "model_edit_add_hint_add": "Or:", + # Model edit - Display + "model_edit_gpu_links": "GPU Links:", + # Model edit - Menu options + "model_edit_manage_gpu_links": "Manage GPU Links", + "model_edit_save_changes": "Save changes", + "model_edit_has_changes": "(has changes)", + "model_edit_no_changes": "(no changes)", + # Model edit - Pending changes messages + "model_edit_name_pending": "Name will be updated when you save changes.", + "model_edit_repo_remove_pending": "Repository info will be removed when you save changes.", + "model_edit_repo_update_pending": "Repository info will be updated when you save changes.", + # Model edit - GPU link management + "model_edit_gpu_links_title": "Manage GPU Links for {name}", + "model_edit_current_gpu_links": "Current GPU links:", + "model_edit_no_gpu_links": "No GPU links configured.", + "model_edit_gpu_options": "Options:", + "model_edit_gpu_add": "Add GPU link", + "model_edit_gpu_remove": "Remove GPU link", + "model_edit_gpu_clear": "Clear all GPU links", + "model_edit_gpu_back": "Back to main menu", + "model_edit_gpu_choose_option": "Choose option", + "model_edit_gpu_none_available": "No GPU models available to link.", + "model_edit_gpu_available_models": "Available GPU models:", + "model_edit_gpu_already_linked": "(already linked)", + "model_edit_gpu_enter_number": "Enter GPU model number to add", + "model_edit_gpu_link_pending": "GPU link will be added when you save changes: 
{name}", + "model_edit_gpu_already_exists": "This GPU model is already linked.", + "model_edit_gpu_invalid_choice": "Invalid choice.", + "model_edit_gpu_invalid_input": "Invalid input.", + "model_edit_gpu_none_to_remove": "No GPU links to remove.", + "model_edit_gpu_choose_to_remove": "Choose GPU link to remove:", + "model_edit_gpu_enter_to_remove": "Enter number to remove", + "model_edit_gpu_remove_pending": "GPU link will be removed when you save changes: {name}", + "model_edit_gpu_none_to_clear": "No GPU links to clear.", + "model_edit_gpu_clear_confirm": "Remove all GPU links?", + "model_edit_gpu_clear_pending": "All GPU links will be removed when you save changes.", + "model_edit_cancelled_short": "Cancelled.", + # Model edit - Save operation + "model_edit_no_changes_to_save": "No changes to save.", + "model_edit_saving": "Saving changes...", + "model_edit_saved": "Changes saved successfully!", + "model_edit_updated_config": "Updated Configuration:", + "model_edit_repo_changed_warning": "⚠ Repository information has changed.", + "model_edit_verify_hint": "Run [cyan]kt model verify[/cyan] to verify model integrity with SHA256 checksums.", + "model_edit_discard_changes": "Discard unsaved changes?", + "model_info_not_found": "Model '{name}' not found.", + "model_info_list_hint": "List all models: kt model list", + "model_remove_not_found": "Model '{name}' not found.", + "model_remove_list_hint": "List models: kt model list", + "model_remove_warning": "Remove model '{name}' from registry?", + "model_remove_note": "Note: This only removes the registry entry. Model files will NOT be deleted from {path}.", + "model_remove_confirm": "Confirm removal?", + "model_remove_cancelled": "Removal cancelled", + "model_removed": "Model '{name}' removed from registry", + "model_remove_failed": "Failed to remove model: {error}", + "model_refresh_checking": "Checking model paths...", + "model_refresh_all_valid": "All models are valid! 
({count} model(s) checked)", + "model_refresh_total": "Total models: {total}", + "model_refresh_missing_found": "Found {count} missing model(s)", + "model_refresh_suggestions": "Suggested Actions", + "model_refresh_remove_hint": "Remove from registry: kt model remove <name>", + "model_refresh_rescan_hint": "Re-scan for models: kt model scan", + "model_verify_not_found": "Model '{name}' not found.", + "model_verify_list_hint": "List models: kt model list", + "model_verify_no_repo": "Model '{name}' has no repository information configured.", + "model_verify_config_hint": "Configure repository: kt model edit {name}", + "model_verify_path_missing": "Model path does not exist: {path}", + "model_verify_starting": "Verifying model integrity...", + "model_verify_progress": "Repository: {repo_type} → {repo_id}", + "model_verify_not_implemented": "SHA256 verification not implemented yet", + "model_verify_future_note": "This feature will fetch official SHA256 hashes from {repo_type} and compare with local files.", + "model_verify_passed": "Verification passed! All files match official hashes.", + "model_verify_failed": "Verification failed!
{count} file(s) have hash mismatches.", + "model_verify_all_no_repos": "No models have repository information configured.", + "model_verify_all_config_hint": "Configure repos using: kt model edit <name>", + "model_verify_all_found": "Found {count} model(s) with repository info", + "model_verify_all_manual_hint": "Verify specific model: kt model verify <name>", # Coming soon "feature_coming_soon": "This feature is coming soon...", }, @@ -465,6 +810,70 @@ MESSAGES: dict[str, dict[str, str]] = { "quant_progress": "正在量化...", "quant_complete": "量化完成!", "quant_input_not_found": "未找到输入模型: {path}", + "quant_cpu_threads": "CPU 线程数: {threads}", + "quant_numa_nodes": "NUMA 节点数: {nodes}", + "quant_time_warning": "量化可能需要 30-60 分钟,具体取决于模型大小。", + "quant_disk_analysis": "磁盘空间分析:", + "quant_source_size": "源模型大小:", + "quant_estimated_size": "预估输出大小:", + "quant_available_space": "可用空间:", + "quant_insufficient_space": "警告:磁盘空间不足!", + "quant_required_space": "所需空间(含20%缓冲):", + "quant_shortage": "不足:", + "quant_may_fail": "量化可能失败或生成不完整的文件。", + "quant_continue_anyway": "仍然继续?", + "quant_settings": "量化设置:", + "quant_registered": "量化模型已注册:{name}", + "quant_view_with": "查看:", + "quant_use_with": "使用:", + "quant_register_failed": "自动注册模型失败:{error}", + "quant_output_exists": "输出路径已存在:{path}", + "quant_using_unique": "使用唯一名称:{path}", + # Interactive quant + "quant_interactive_title": "交互式量化配置", + "quant_new_model_notice": "⚠ 注意:部分新模型暂时无法量化(转换脚本未适配),推荐使用原精度进行推理(无需转换权重)。", + "quant_no_moe_models": "未找到可量化的 MoE 模型。", + "quant_only_moe": "只有 MoE 模型(如 DeepSeek-V3)可以被量化为 AMX 格式。", + "quant_add_models": "添加模型:{command}", + "quant_moe_available": "可量化的 MoE 模型:", + "quant_select_model": "选择要量化的模型", + "quant_invalid_choice": "无效选择", + "quant_step2_method": "第 2 步:量化方法", + "quant_method_label": "量化方法:", + "quant_int4_desc": "INT4", + "quant_int8_desc": "INT8", + "quant_select_method": "选择量化方法", + "quant_input_type_label": "输入权重类型:", + "quant_fp8_desc": "FP8(适用于 8 位浮点权重)", + "quant_fp16_desc": "FP16(适用于 16 位浮点权重)", +
"quant_bf16_desc": "BF16(适用于 Brain Float 16 权重)", + "quant_select_input_type": "选择输入类型", + "quant_step3_cpu": "第 3 步:CPU 配置", + "quant_cpu_threads_prompt": "CPU 线程数(1 到 {max})", + "quant_numa_nodes_prompt": "NUMA 节点数(1 到 {max})", + "quant_use_gpu_label": "是否使用 GPU 进行转换?", + "quant_gpu_speedup": "GPU 可以显著加快量化速度", + "quant_enable_gpu": "启用 GPU 加速?", + "quant_step4_output": "第 4 步:输出路径", + "quant_default_path": "默认:", + "quant_use_default": "使用默认输出路径?", + "quant_custom_path": "输入自定义输出路径", + "quant_output_exists_warn": "⚠ 输出路径已存在:{path}", + "quant_using_unique_name": "→ 使用唯一名称:{path}", + "quant_config_summary": "配置摘要", + "quant_summary_model": "模型:", + "quant_summary_method": "方法:", + "quant_summary_input_type": "输入类型:", + "quant_summary_cpu_threads": "CPU 线程数:", + "quant_summary_numa": "NUMA 节点数:", + "quant_summary_gpu": "使用 GPU:", + "quant_summary_output": "输出路径:", + "quant_start_question": "开始量化?", + "quant_cancelled": "已取消", + "quant_config_complete": "配置完成", + "quant_time_elapsed": "耗时:", + "yes": "是", + "no": "否", # SFT command "sft_mode_train": "训练模式", "sft_mode_chat": "聊天模式", @@ -522,6 +931,113 @@ MESSAGES: dict[str, dict[str, str]] = { "chat_proxy_detected": "检测到环境中存在代理设置", "chat_proxy_confirm": "是否使用代理连接?", "chat_proxy_disabled": "已在本次会话中禁用代理", + "chat_openai_required": "聊天功能需要 OpenAI Python SDK。", + "chat_install_hint": "安装命令:", + "chat_title": "KTransformers 对话", + "chat_server": "服务器", + "chat_temperature": "温度", + "chat_max_tokens": "最大 tokens", + "chat_help_hint": "输入 '/help' 查看命令,'/quit' 退出", + "chat_connecting": "正在连接服务器...", + "chat_no_models": "服务器上没有可用模型", + "chat_model_not_found": "未找到模型 '{model}'。可用模型:{available}", + "chat_connected": "已连接到模型:{model}", + "chat_connect_failed": "连接服务器失败:{error}", + "chat_server_not_running": "请确保模型服务器正在运行:", + "chat_user_prompt": "用户", + "chat_assistant_prompt": "助手", + "chat_generation_error": "生成回复时出错:{error}", + "chat_interrupted": "对话已中断。再见!", + "chat_history_saved": "历史记录已保存到:{path}", + "chat_goodbye": "再见!", 
+ "chat_help_title": "可用命令:", + "chat_help_content": "/help, /h - 显示此帮助信息\n/quit, /exit, /q - 退出聊天\n/clear, /c - 清除对话历史\n/history, /hist - 显示对话历史\n/info, /i - 显示当前设置\n/retry, /r - 重新生成上一个回复", + "chat_history_cleared": "对话历史已清除", + "chat_no_history": "暂无对话历史", + "chat_history_title": "历史记录({count} 条消息)", + "chat_info_title": "当前设置:", + "chat_info_content": "温度:{temperature}\n最大 tokens:{max_tokens}\n消息数:{messages}", + "chat_retrying": "正在重试上一个回复...", + "chat_no_retry": "没有可重试的回复", + "chat_unknown_command": "未知命令:{command}", + "chat_unknown_hint": "输入 /help 查看可用命令", + # Run Interactive + "run_int_no_moe_models": "未找到 MoE GPU 模型。", + "run_int_add_models": "添加模型:kt model scan", + "run_int_list_all": "列出所有模型:kt model list --all", + "run_int_step1_title": "第 1 步:选择模型(GPU MoE 模型)", + "run_int_select_model": "选择模型", + "run_int_step2_title": "第 2 步:选择推理方法", + "run_int_method_raw": "RAW 精度(FP8/FP8_PERCHANNEL/BF16/RAWINT4)", + "run_int_method_amx": "AMX 量化(INT4/INT8)", + "run_int_method_gguf": "GGUF(Llamafile)", + "run_int_method_saved": "使用已保存的配置", + "run_int_select_method": "选择推理方法", + "run_int_raw_precision": "RAW 精度:", + "run_int_select_precision": "选择精度", + "run_int_amx_method": "AMX 方法:", + "run_int_select_amx": "选择 AMX 方法", + "run_int_step3_title": "第 3 步:NUMA 和 CPU 配置", + "run_int_numa_nodes": "NUMA 节点数(1-{max})", + "run_int_cpu_threads": "每个 NUMA 的 CPU 线程数(1-{max})", + "run_int_amx_warning": "⚠ 警告:AMX INT4/INT8 需要兼容的 CPU。检查命令:kt doctor", + "run_int_step4_title": "第 4 步:GPU 专家配置", + "run_int_gpu_experts": "每层 GPU 专家数(0-{max})", + "run_int_gpu_experts_info": "总专家数:{total},每 token 激活:{active}", + "run_int_step5_title": "第 5 步:KV Cache 配置", + "run_int_kv_cache_size": "KV Cache 大小(tokens)", + "run_int_chunk_prefill": "启用分块预填充?", + "run_int_chunk_size": "分块预填充大小(tokens)", + "run_int_gpu_prefill_threshold": "GPU 预填充阈值(tokens)", + "run_int_step6_title": "第 6 步:GPU 选择和张量并行", + "run_int_available_gpus": "可用 GPU:", + "run_int_gpu_id": "GPU {id}", + "run_int_vram_info": 
"{name}(总计 {total:.1f}GB,空闲 {free:.1f}GB)", + "run_int_select_gpus": "选择 GPU ID(逗号分隔)", + "run_int_invalid_gpu_range": "所有 GPU ID 必须在 0 到 {max} 之间", + "run_int_tp_size": "TP 大小(必须是 2 的幂:1,2,4,8...)", + "run_int_tp_mismatch": "TP 大小必须与选择的 GPU 数量匹配({count})", + "run_int_tp_not_power_of_2": "TP 大小必须是 2 的幂", + "run_int_mem_fraction": "静态内存占用比例(0.0-1.0)", + "run_int_using_saved_mem": "使用已保存的内存占用比例:{fraction}", + "run_int_step7_title": "第 7 步:解析器配置(可选)", + "run_int_tool_call_parser": "工具调用解析器(按回车跳过)", + "run_int_reasoning_parser": "推理解析器(按回车跳过)", + "run_int_step8_title": "第 8 步:主机和端口配置", + "run_int_host": "主机", + "run_int_port": "端口", + "run_int_port_occupied": "⚠ 端口 {port} 已被占用", + "run_int_port_suggestion": "建议使用可用端口:{port}", + "run_int_use_suggested": "使用建议的端口?", + "run_int_saved_configs": "已保存的配置:", + "run_int_config_name": "配置 {num}", + "run_int_kt_method": "KT 方法:", + "run_int_numa_nodes_label": "NUMA 节点:", + "run_int_cpu_threads_label": "CPU 线程:", + "run_int_gpu_experts_label": "GPU 专家:", + "run_int_tp_size_label": "TP 大小:", + "run_int_mem_fraction_label": "内存占用比例:", + "run_int_server_label": "服务器:", + "run_int_kv_cache_label": "KV Cache:", + "run_int_chunk_prefill_label": "分块预填充:", + "run_int_gpu_prefill_label": "GPU 预填充阈值:", + "run_int_tool_parser_label": "工具调用解析器:", + "run_int_reasoning_parser_label": "推理解析器:", + "run_int_command_label": "命令:", + "run_int_select_config": "选择配置", + "run_int_gpu_select_required": "请选择 {tp} 个 GPU(来自已保存配置的 TP 大小)", + "run_int_port_check_title": "端口配置", + "run_int_port_checking": "正在检查端口 {port} 可用性...", + "run_int_port_available": "端口 {port} 可用", + "run_int_saved_config_title": "已保存的配置", + "run_int_save_config_title": "保存配置", + "run_int_save_config_prompt": "保存此配置以供将来使用?", + "run_int_config_name_prompt": "配置名称", + "run_int_config_name_default": "配置 {timestamp}", + "run_int_config_saved": "配置已保存:{name}", + "run_int_config_summary": "配置完成", + "run_int_model_label": "模型:", + "run_int_selected_gpus_label": "已选择的 GPU:", # Model command 
"model_supported_title": "KTransformers 支持的模型", "model_column_model": "模型", @@ -557,6 +1073,180 @@ MESSAGES: dict[str, dict[str, str]] = { "model_column_name": "名称", "model_column_hf_repo": "HuggingFace 仓库", "model_column_aliases": "别名", + # Model management - new user registry system + "model_no_registered_models": "尚未注册任何模型。", + "model_scan_hint": "扫描模型: kt model scan", + "model_add_hint": "添加模型: kt model add /path/to/model", + "model_registered_models_title": "已注册的模型", + "model_column_format": "格式", + "model_column_repo": "仓库", + "model_column_sha256": "SHA256", + "model_non_moe_hidden_hint": "检测到 {count} 个非MoE模型,使用 kt model list --all 展示全部", + "model_usage_title": "常用操作:", + "model_usage_info": "查看详情:", + "model_usage_edit": "编辑模型:", + "model_usage_verify": "校验权重:", + "model_usage_quant": "量化模型:", + "model_usage_run": "运行模型:", + "model_usage_scan": "扫描模型:", + "model_usage_add": "添加模型:", + "model_usage_verbose": "查看包含文件详情:", + "model_no_storage_paths": "未配置存储路径。", + "model_add_path_hint": "添加存储路径: kt config set model.storage_paths /path/to/models", + "model_scanning_paths": "正在扫描配置的存储路径...", + "model_scanning_progress": "扫描中: {path}", + "model_scan_warnings_title": "警告", + "model_scan_no_models_found": "在配置的路径中未找到模型。", + "model_scan_check_paths_hint": "检查存储路径: kt config get model.storage_paths", + "model_scan_min_size_hint": "文件夹必须 ≥{size}GB 才能被识别为模型。", + "model_scan_found_title": "发现 {count} 个新模型", + "model_column_path": "路径", + "model_column_size": "大小", + "model_scan_auto_adding": "正在自动添加模型...", + "model_added": "已添加: {name}", + "model_add_failed": "添加 {name} 失败: {error}", + "model_scan_complete": "扫描完成!已添加 {count} 个模型。", + "model_scan_interactive_prompt": "命令: edit | del | done", + "model_scan_cmd_edit": "设置模型自定义名称和仓库", + "model_scan_cmd_delete": "跳过此模型", + "model_scan_cmd_done": "完成并添加模型", + "model_scan_marked_skip": "已跳过模型 #{id}", + "model_scan_invalid_id": "无效的模型 ID: {id}", + "model_scan_invalid_command": "无效命令。使用: edit | del | done", + 
"model_scan_edit_model": "编辑模型 {id}", + "model_scan_edit_note": "您可以在添加到注册表前更改模型名称和配置仓库信息", + "model_scan_adding_models": "正在添加 {count} 个模型...", + "model_scan_next_steps": "后续步骤", + "model_scan_view_hint": "查看已注册模型: kt model list", + "model_scan_edit_hint": "编辑模型详情: kt model edit ", + "model_scan_no_models_added": "未添加任何模型。", + "model_add_path_not_exist": "错误: 路径不存在: {path}", + "model_add_not_directory": "错误: 路径不是目录: {path}", + "model_add_already_registered": "此路径已注册为: {name}", + "model_add_view_hint": "查看: kt model info {name}", + "model_add_scanning": "正在扫描模型文件...", + "model_add_scan_failed": "扫描模型失败: {error}", + "model_add_no_model_files": "在 {path} 中未找到模型文件", + "model_add_supported_formats": "支持: *.safetensors, *.gguf (文件夹 ≥10GB)", + "model_add_detected": "检测到: {format} 格式, {size}, {count} 个文件", + "model_add_name_conflict": "名称 '{name}' 已存在。", + "model_add_prompt_name": "为此模型输入名称", + "model_add_name_exists": "名称已存在。请选择其他名称:", + "model_add_configure_repo": "配置仓库信息以进行 SHA256 验证?", + "model_add_repo_type_prompt": "选择仓库类型:", + "model_add_choice": "选择", + "model_add_repo_id_prompt": "输入仓库 ID (例如: deepseek-ai/DeepSeek-V3)", + "model_add_success": "成功添加模型: {name}", + "model_add_verify_hint": "验证完整性: kt model verify {name}", + "model_add_edit_later_hint": "稍后编辑详情: kt model edit {name}", + "model_add_failed_generic": "添加模型失败: {error}", + "model_edit_not_found": "未找到模型 '{name}'。", + "model_edit_list_hint": "列出模型: kt model list", + "model_edit_current_config": "当前配置", + "model_edit_what_to_edit": "您想编辑什么?", + "model_edit_option_name": "编辑名称", + "model_edit_option_repo": "配置仓库信息", + "model_edit_option_delete": "删除此模型", + "model_edit_option_cancel": "取消 / 退出", + "model_edit_choice_prompt": "选择选项", + "model_edit_new_name": "输入新名称", + "model_edit_name_conflict": "名称 '{name}' 已存在。请选择其他名称:", + "model_edit_name_updated": "名称已更新: {old} → {new}", + "model_edit_repo_type_prompt": "仓库类型 (或按回车删除仓库信息):", + "model_edit_repo_remove": "删除仓库信息", + "model_edit_repo_id_prompt": "输入仓库 ID", + 
"model_edit_repo_removed": "仓库信息已删除", + "model_edit_repo_updated": "仓库已配置: {repo_type} → {repo_id}", + "model_edit_delete_warning": "从注册表中删除模型 '{name}'?", + "model_edit_delete_note": "注意: 这只会删除注册表条目。{path} 中的模型文件不会被删除。", + "model_edit_delete_confirm": "确认删除?", + "model_edit_deleted": "模型 '{name}' 已从注册表中删除", + "model_edit_delete_cancelled": "删除已取消", + "model_edit_cancelled": "编辑已取消", + # Model edit - Interactive selection + "model_edit_select_title": "选择要编辑的模型", + "model_edit_select_model": "选择模型", + "model_edit_invalid_choice": "无效选择", + "model_edit_no_models": "注册表中未找到模型。", + "model_edit_add_hint_scan": "添加模型:", + "model_edit_add_hint_add": "或:", + # Model edit - Display + "model_edit_gpu_links": "GPU 链接:", + # Model edit - Menu options + "model_edit_manage_gpu_links": "管理 GPU 链接", + "model_edit_save_changes": "保存更改", + "model_edit_has_changes": "(有更改)", + "model_edit_no_changes": "(无更改)", + # Model edit - Pending changes messages + "model_edit_name_pending": "名称将在保存更改时更新。", + "model_edit_repo_remove_pending": "仓库信息将在保存更改时删除。", + "model_edit_repo_update_pending": "仓库信息将在保存更改时更新。", + # Model edit - GPU link management + "model_edit_gpu_links_title": "管理 {name} 的 GPU 链接", + "model_edit_current_gpu_links": "当前 GPU 链接:", + "model_edit_no_gpu_links": "未配置 GPU 链接。", + "model_edit_gpu_options": "选项:", + "model_edit_gpu_add": "添加 GPU 链接", + "model_edit_gpu_remove": "删除 GPU 链接", + "model_edit_gpu_clear": "清除所有 GPU 链接", + "model_edit_gpu_back": "返回主菜单", + "model_edit_gpu_choose_option": "选择选项", + "model_edit_gpu_none_available": "没有可链接的 GPU 模型。", + "model_edit_gpu_available_models": "可用的 GPU 模型:", + "model_edit_gpu_already_linked": "(已链接)", + "model_edit_gpu_enter_number": "输入要添加的 GPU 模型编号", + "model_edit_gpu_link_pending": "GPU 链接将在保存更改时添加: {name}", + "model_edit_gpu_already_exists": "此 GPU 模型已链接。", + "model_edit_gpu_invalid_choice": "无效选择。", + "model_edit_gpu_invalid_input": "无效输入。", + "model_edit_gpu_none_to_remove": "没有可删除的 GPU 链接。", + "model_edit_gpu_choose_to_remove": 
"选择要删除的 GPU 链接:", + "model_edit_gpu_enter_to_remove": "输入要删除的编号", + "model_edit_gpu_remove_pending": "GPU 链接将在保存更改时删除: {name}", + "model_edit_gpu_none_to_clear": "没有可清除的 GPU 链接。", + "model_edit_gpu_clear_confirm": "删除所有 GPU 链接?", + "model_edit_gpu_clear_pending": "所有 GPU 链接将在保存更改时删除。", + "model_edit_cancelled_short": "已取消。", + # Model edit - Save operation + "model_edit_no_changes_to_save": "没有更改可保存。", + "model_edit_saving": "正在保存更改...", + "model_edit_saved": "更改保存成功!", + "model_edit_updated_config": "更新后的配置:", + "model_edit_repo_changed_warning": "⚠ 仓库信息已更改。", + "model_edit_verify_hint": "运行 [cyan]kt model verify[/cyan] 以使用 SHA256 校验和验证模型完整性。", + "model_edit_discard_changes": "放弃未保存的更改?", + "model_info_not_found": "未找到模型 '{name}'。", + "model_info_list_hint": "列出所有模型: kt model list", + "model_remove_not_found": "未找到模型 '{name}'。", + "model_remove_list_hint": "列出模型: kt model list", + "model_remove_warning": "从注册表中删除模型 '{name}'?", + "model_remove_note": "注意: 这只会删除注册表条目。模型文件不会从 {path} 中删除。", + "model_remove_confirm": "确认删除?", + "model_remove_cancelled": "删除已取消", + "model_removed": "模型 '{name}' 已从注册表中删除", + "model_remove_failed": "删除模型失败: {error}", + "model_refresh_checking": "正在检查模型路径...", + "model_refresh_all_valid": "所有模型都有效! 
(已检查 {count} 个模型)", + "model_refresh_total": "总模型数: {total}", + "model_refresh_missing_found": "发现 {count} 个缺失的模型", + "model_refresh_suggestions": "建议操作", + "model_refresh_remove_hint": "从注册表中删除: kt model remove ", + "model_refresh_rescan_hint": "重新扫描模型: kt model scan", + "model_verify_not_found": "未找到模型 '{name}'。", + "model_verify_list_hint": "列出模型: kt model list", + "model_verify_no_repo": "模型 '{name}' 未配置仓库信息。", + "model_verify_config_hint": "配置仓库: kt model edit {name}", + "model_verify_path_missing": "模型路径不存在: {path}", + "model_verify_starting": "正在验证模型完整性...", + "model_verify_progress": "仓库: {repo_type} → {repo_id}", + "model_verify_not_implemented": "SHA256 验证尚未实现", + "model_verify_future_note": "此功能将从 {repo_type} 获取官方 SHA256 哈希值并与本地文件进行比较。", + "model_verify_passed": "验证通过!所有文件都与官方哈希匹配。", + "model_verify_failed": "验证失败!{count} 个文件的哈希不匹配。", + "model_verify_all_no_repos": "没有模型配置了仓库信息。", + "model_verify_all_config_hint": "配置仓库使用: kt model edit ", + "model_verify_all_found": "发现 {count} 个配置了仓库信息的模型", + "model_verify_all_manual_hint": "验证特定模型: kt model verify ", # Coming soon "feature_coming_soon": "此功能即将推出...", }, diff --git a/kt-kernel/python/cli/main.py b/kt-kernel/python/cli/main.py index 43449f9..0f368a1 100644 --- a/kt-kernel/python/cli/main.py +++ b/kt-kernel/python/cli/main.py @@ -5,6 +5,10 @@ KTransformers CLI - A unified command-line interface for KTransformers. 
""" import sys +import warnings + +# Suppress numpy subnormal warnings +warnings.filterwarnings("ignore", message="The value of the smallest subnormal") import typer @@ -28,6 +32,7 @@ def _get_help(key: str) -> str: "run": {"en": "Start model inference server", "zh": "启动模型推理服务器"}, "chat": {"en": "Interactive chat with running model", "zh": "与运行中的模型进行交互式聊天"}, "quant": {"en": "Quantize model weights", "zh": "量化模型权重"}, + "edit": {"en": "Edit model information", "zh": "编辑模型信息"}, "bench": {"en": "Run full benchmark", "zh": "运行完整基准测试"}, "microbench": {"en": "Run micro-benchmark", "zh": "运行微基准测试"}, "doctor": {"en": "Diagnose environment issues", "zh": "诊断环境问题"}, @@ -43,7 +48,7 @@ def _get_help(key: str) -> str: app = typer.Typer( name="kt", help="KTransformers CLI - A unified command-line interface for KTransformers.", - no_args_is_help=True, + no_args_is_help=False, # Handle no-args case manually to support first-run setup add_completion=False, # Use static completion scripts instead of dynamic completion rich_markup_mode="rich", ) @@ -66,20 +71,7 @@ def _update_help_texts() -> None: group_info.help = _get_help(group_info.name) -# Register commands -app.command(name="version", help="Show version information")(version.version) -# Run command is handled specially in main() to allow extra args -# (not registered here to avoid typer's argument parsing) -app.command(name="chat", help="Interactive chat with running model")(chat.chat) -app.command(name="quant", help="Quantize model weights")(quant.quant) -app.command(name="bench", help="Run full benchmark")(bench.bench) -app.command(name="microbench", help="Run micro-benchmark")(bench.microbench) -app.command(name="doctor", help="Diagnose environment issues")(doctor.doctor) - -# Register sub-apps -app.add_typer(model.app, name="model", help="Manage models and storage paths") -app.add_typer(config.app, name="config", help="Manage configuration") -app.add_typer(sft.app, name="sft", help="Fine-tuning with LlamaFactory") +# 
Commands are registered later after tui_command is defined def check_first_run() -> None: @@ -116,7 +108,7 @@ def _show_first_run_setup(settings) -> None: from rich.spinner import Spinner from rich.live import Live - from kt_kernel.cli.utils.environment import scan_storage_locations, format_size_gb, scan_models_in_location + from kt_kernel.cli.utils.environment import scan_storage_locations, format_size_gb console = Console() @@ -140,15 +132,8 @@ def _show_first_run_setup(settings) -> None: console.print(" [cyan][2][/cyan] 中文 (Chinese)") console.print() - while True: - choice = Prompt.ask("Enter choice / 输入选择", choices=["1", "2"], default="1") - - if choice == "1": - lang = "en" - break - elif choice == "2": - lang = "zh" - break + choice = Prompt.ask("Enter choice / 输入选择", choices=["1", "2"], default="1") + lang = "en" if choice == "1" else "zh" # Save language setting settings.set("general.language", lang) @@ -161,6 +146,131 @@ def _show_first_run_setup(settings) -> None: else: console.print("[green]✓[/green] Language set to English") + # Model discovery section + console.print() + if lang == "zh": + console.print("[bold]发现模型权重[/bold]") + console.print() + console.print("[dim]扫描系统中已有的模型权重文件,以便快速添加到模型列表。[/dim]") + console.print() + console.print(" [cyan][1][/cyan] 全局扫描 (自动扫描所有非系统路径)") + console.print(" [cyan][2][/cyan] 手动指定路径 (可添加多个)") + console.print(" [cyan][3][/cyan] 跳过 (稍后手动添加)") + console.print() + scan_choice = Prompt.ask("选择扫描方式", choices=["1", "2", "3"], default="1") + else: + console.print("[bold]Discover Model Weights[/bold]") + console.print() + console.print("[dim]Scan existing model weights on your system to quickly add them to the model list.[/dim]") + console.print() + console.print(" [cyan][1][/cyan] Global scan (auto-scan all non-system paths)") + console.print(" [cyan][2][/cyan] Manual paths (add multiple paths)") + console.print(" [cyan][3][/cyan] Skip (add manually later)") + console.print() + scan_choice = Prompt.ask("Select scan method", 
choices=["1", "2", "3"], default="1") + + if scan_choice == "1": + # Global scan + from kt_kernel.cli.utils.model_discovery import discover_and_register_global, format_discovery_summary + + console.print() + try: + total_found, new_found, registered = discover_and_register_global( + min_size_gb=2.0, max_depth=6, show_progress=True, lang=lang + ) + + format_discovery_summary( + total_found=total_found, + new_found=new_found, + registered=registered, + lang=lang, + show_models=True, + max_show=10, + ) + + except Exception as e: + console.print(f"[yellow]Warning: Scan failed - {e}[/yellow]") + + elif scan_choice == "2": + # Manual path specification + from kt_kernel.cli.utils.model_discovery import discover_and_register_path + import os + + discovered_paths = set() # Track paths discovered in this session + total_registered = [] + + while True: + console.print() + if lang == "zh": + path = Prompt.ask("输入要扫描的路径 (例如: /mnt/data/models)") + else: + path = Prompt.ask("Enter path to scan (e.g., /mnt/data/models)") + + # Expand and validate path + path = os.path.expanduser(path) + + if not os.path.exists(path): + if lang == "zh": + console.print(f"[yellow]警告: 路径不存在: {path}[/yellow]") + else: + console.print(f"[yellow]Warning: Path does not exist: {path}[/yellow]") + continue + + if not os.path.isdir(path): + if lang == "zh": + console.print(f"[yellow]警告: 不是一个目录: {path}[/yellow]") + else: + console.print(f"[yellow]Warning: Not a directory: {path}[/yellow]") + continue + + # Scan this path + console.print() + try: + total_found, new_found, registered = discover_and_register_path( + path=path, min_size_gb=2.0, existing_paths=discovered_paths, show_progress=True, lang=lang + ) + + # Update discovered paths + for model in registered: + discovered_paths.add(model.path) + total_registered.extend(registered) + + console.print() + if lang == "zh": + console.print(f"[green]✓[/green] 在此路径找到 {total_found} 个模型,其中 {new_found} 个为新模型") + else: + console.print(f"[green]✓[/green] Found 
{total_found} models in this path, {new_found} are new") + + if new_found > 0: + for model in registered[:5]: + console.print(f" • {model.name} ({model.format})") + + if len(registered) > 5: + if lang == "zh": + console.print(f" [dim]... 还有 {len(registered) - 5} 个新模型[/dim]") + else: + console.print(f" [dim]... and {len(registered) - 5} more new models[/dim]") + + except Exception as e: + console.print(f"[red]Error scanning path: {e}[/red]") + + # Ask if continue + console.print() + if lang == "zh": + continue_scan = Confirm.ask("是否继续添加其他路径?", default=False) + else: + continue_scan = Confirm.ask("Continue adding more paths?", default=False) + + if not continue_scan: + break + + if total_registered: + console.print() + if lang == "zh": + console.print(f"[green]✓[/green] 总共发现 {len(total_registered)} 个新模型") + else: + console.print(f"[green]✓[/green] Total {len(total_registered)} new models discovered") + # Model storage path selection console.print() console.print(f"[bold]{t('setup_model_path_title')}[/bold]") @@ -174,16 +284,7 @@ def _show_first_run_setup(settings) -> None: console.print() if locations: - # Scan for models in each location - console.print(f"[dim]{t('setup_scanning_models')}[/dim]") - location_models: dict[str, list] = {} - for loc in locations[:5]: - models = scan_models_in_location(loc, max_depth=2) - if models: - location_models[loc.path] = models - console.print() - - # Show options + # Show storage location options for i, loc in enumerate(locations[:5], 1): # Show top 5 options available = format_size_gb(loc.available_gb) total = format_size_gb(loc.total_gb) @@ -194,22 +295,8 @@ def _show_first_run_setup(settings) -> None: else: option_str = t("setup_disk_option", path=loc.path, available=available, total=total) - # Add model count if any - if loc.path in location_models: - model_count = len(location_models[loc.path]) - option_str += f" [green]✓ {t('setup_location_has_models', count=model_count)}[/green]" - console.print(f" [cyan][{i}][/cyan] 
{option_str}") - # Show first few models found in this location - if loc.path in location_models: - for model in location_models[loc.path][:3]: # Show up to 3 models - size_str = format_size_gb(model.size_gb) - console.print(f" [dim]• {model.name} ({size_str})[/dim]") - if len(location_models[loc.path]) > 3: - remaining = len(location_models[loc.path]) - 3 - console.print(f" [dim] ... +{remaining} more[/dim]") - # Custom path option custom_idx = min(len(locations), 5) + 1 console.print(f" [cyan][{custom_idx}][/cyan] {t('setup_custom_path')}") @@ -323,51 +410,28 @@ def _install_shell_completion() -> None: # Detect current shell shell = os.environ.get("SHELL", "") - if "zsh" in shell: - shell_name = "zsh" - elif "fish" in shell: - shell_name = "fish" - else: - shell_name = "bash" + shell_name = "zsh" if "zsh" in shell else "fish" if "fish" in shell else "bash" try: cli_dir = Path(__file__).parent completions_dir = cli_dir / "completions" home = Path.home() - installed = False + def install_completion(src_name: str, dest_dir: Path, dest_name: str) -> None: + """Install completion file from source to destination.""" + src_file = completions_dir / src_name + if src_file.exists(): + dest_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(src_file, dest_dir / dest_name) if shell_name == "bash": - # Use XDG standard location for bash-completion (auto-loaded) - src_file = completions_dir / "kt-completion.bash" - dest_dir = home / ".local" / "share" / "bash-completion" / "completions" - dest_file = dest_dir / "kt" - - if src_file.exists(): - dest_dir.mkdir(parents=True, exist_ok=True) - shutil.copy2(src_file, dest_file) - installed = True - + install_completion( + "kt-completion.bash", home / ".local" / "share" / "bash-completion" / "completions", "kt" + ) elif shell_name == "zsh": - src_file = completions_dir / "_kt" - dest_dir = home / ".zfunc" - dest_file = dest_dir / "_kt" - - if src_file.exists(): - dest_dir.mkdir(parents=True, exist_ok=True) - shutil.copy2(src_file, 
dest_file) - installed = True - + install_completion("_kt", home / ".zfunc", "_kt") elif shell_name == "fish": - # Fish auto-loads from this directory - src_file = completions_dir / "kt.fish" - dest_dir = home / ".config" / "fish" / "completions" - dest_file = dest_dir / "kt.fish" - - if src_file.exists(): - dest_dir.mkdir(parents=True, exist_ok=True) - shutil.copy2(src_file, dest_file) - installed = True + install_completion("kt.fish", home / ".config" / "fish" / "completions", "kt.fish") # Mark as installed settings.set("general._completion_installed", True) @@ -403,6 +467,20 @@ def _apply_saved_language() -> None: set_lang(lang) +app.command(name="version", help="Show version information")(version.version) +app.command(name="chat", help="Interactive chat with running model")(chat.chat) +app.command(name="quant", help="Quantize model weights")(quant.quant) +app.command(name="edit", help="Edit model information")(model.edit_model) +app.command(name="bench", help="Run full benchmark")(bench.bench) +app.command(name="microbench", help="Run micro-benchmark")(bench.microbench) +app.command(name="doctor", help="Diagnose environment issues")(doctor.doctor) + +# Register sub-apps +app.add_typer(model.app, name="model", help="Manage models and storage paths") +app.add_typer(config.app, name="config", help="Manage configuration") +app.add_typer(sft.app, name="sft", help="Fine-tuning with LlamaFactory") + + def main(): """Main entry point.""" # Apply saved language setting first (before anything else for correct help display) @@ -414,7 +492,7 @@ def main(): # Check for first run (but not for certain commands) # Skip first-run check for: --help, config commands, version args = sys.argv[1:] if len(sys.argv) > 1 else [] - skip_commands = ["--help", "-h", "config", "version", "--version"] + skip_commands = ["--help", "-h", "config", "version", "--version", "--no-tui"] should_check_first_run = True for arg in args: @@ -422,12 +500,35 @@ def main(): should_check_first_run = False 
break + # Handle no arguments case + if not args: + # Check if this is first run + from kt_kernel.cli.config.settings import DEFAULT_CONFIG_FILE, get_settings + + is_first_run = False + if not DEFAULT_CONFIG_FILE.exists(): + is_first_run = True + else: + settings = get_settings() + if not settings.get("general._initialized"): + is_first_run = True + + if is_first_run: + # First run - start initialization + _install_shell_completion() + check_first_run() + return + else: + # Not first run - show help + app(["--help"]) + return + # Auto-install shell completion on first run if should_check_first_run: _install_shell_completion() # Check first run before running commands - if should_check_first_run and args: + if should_check_first_run: check_first_run() # Handle "run" command specially to pass through unknown options diff --git a/kt-kernel/python/cli/utils/analyze_moe_model.py b/kt-kernel/python/cli/utils/analyze_moe_model.py new file mode 100644 index 0000000..ecc5778 --- /dev/null +++ b/kt-kernel/python/cli/utils/analyze_moe_model.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +""" +快速分析 MoE 模型 - 基于 config.json +(复用 sglang 的模型注册表和判断逻辑) +""" +import json +import hashlib +from pathlib import Path +from typing import Optional, Dict, Any + + +def _get_sglang_moe_architectures(): + """ + 从 sglang 的模型注册表获取所有 MoE 架构 + + 复用 sglang 的代码,这样 sglang 更新后自动支持新模型 + """ + try: + import sys + + # 添加 sglang 路径到 sys.path + sglang_path = Path("/mnt/data2/ljq/sglang/python") + if sglang_path.exists() and str(sglang_path) not in sys.path: + sys.path.insert(0, str(sglang_path)) + + # 直接导入 sglang 的 ModelRegistry + # 注意:这需要 sglang 及其依赖正确安装 + from sglang.srt.models.registry import ModelRegistry + + # 获取所有支持的架构 + supported_archs = ModelRegistry.get_supported_archs() + + # 过滤出 MoE 模型(名称包含 Moe) + moe_archs = {arch for arch in supported_archs if "Moe" in arch or "moe" in arch.lower()} + + # 手动添加一些不带 "Moe" 字样但是 MoE 模型的架构 + # DeepSeek V2/V3 系列 + deepseek_moe = {arch for arch in supported_archs if 
arch.startswith("Deepseek") or arch.startswith("deepseek")} + moe_archs.update(deepseek_moe) + + # DBRX 也是 MoE 模型 + dbrx_moe = {arch for arch in supported_archs if "DBRX" in arch or "dbrx" in arch.lower()} + moe_archs.update(dbrx_moe) + + # Grok 也是 MoE 模型 + grok_moe = {arch for arch in supported_archs if "Grok" in arch or "grok" in arch.lower()} + moe_archs.update(grok_moe) + + return moe_archs + except Exception as e: + # 如果 sglang 不可用,返回空集合 + # 这种情况下,后续会使用配置文件中的其他判断方法 + import warnings + + warnings.warn(f"Failed to load MoE architectures from sglang: {e}. Using fallback detection methods.") + return set() + + +# 获取 MoE 架构列表(优先从 sglang 获取) +MOE_ARCHITECTURES = _get_sglang_moe_architectures() + + +def _get_cache_file(): + """获取集中式缓存文件路径""" + cache_dir = Path.home() / ".ktransformers" / "cache" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir / "moe_analysis_v2.json" + + +def _load_all_cache(): + """加载所有缓存数据""" + cache_file = _get_cache_file() + if not cache_file.exists(): + return {} + + try: + with open(cache_file, "r") as f: + return json.load(f) + except Exception: + return {} + + +def _save_all_cache(cache_data): + """保存所有缓存数据""" + cache_file = _get_cache_file() + try: + with open(cache_file, "w") as f: + json.dump(cache_data, f, indent=2) + except Exception as e: + import warnings + + warnings.warn(f"Failed to save MoE cache: {e}") + + +def _compute_config_fingerprint(config_path: Path) -> Optional[str]: + """计算 config.json 指纹""" + if not config_path.exists(): + return None + + try: + stat = config_path.stat() + # 使用文件大小和修改时间作为指纹 + fingerprint_str = f"{config_path.name}:{stat.st_size}:{int(stat.st_mtime)}" + return hashlib.md5(fingerprint_str.encode()).hexdigest() + except Exception: + return None + + +def _load_cache(model_path: Path) -> Optional[Dict[str, Any]]: + """加载指定模型的缓存""" + model_path_str = str(model_path.resolve()) + all_cache = _load_all_cache() + + if model_path_str not in all_cache: + return None + + try: + cache_entry = 
all_cache[model_path_str] + + # 验证缓存版本 + cache_version = cache_entry.get("cache_version", 0) + if cache_version != 2: + return None + + # 验证 config.json 指纹 + config_path = model_path / "config.json" + current_fingerprint = _compute_config_fingerprint(config_path) + if cache_entry.get("fingerprint") != current_fingerprint: + return None + + return cache_entry.get("result") + except Exception: + return None + + +def _save_cache(model_path: Path, result: Dict[str, Any]): + """保存指定模型的缓存""" + model_path_str = str(model_path.resolve()) + + try: + config_path = model_path / "config.json" + fingerprint = _compute_config_fingerprint(config_path) + + all_cache = _load_all_cache() + + all_cache[model_path_str] = { + "fingerprint": fingerprint, + "result": result, + "cache_version": 2, + "last_updated": __import__("datetime").datetime.now().isoformat(), + } + + _save_all_cache(all_cache) + except Exception as e: + import warnings + + warnings.warn(f"Failed to save MoE cache for {model_path}: {e}") + + +def _load_config_json(model_path: Path) -> Optional[Dict[str, Any]]: + """读取 config.json 文件 + + 参考 sglang 的 get_config() 实现 + """ + config_path = model_path / "config.json" + + if not config_path.exists(): + return None + + try: + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + return config + except Exception: + return None + + +def _is_moe_model(config: Dict[str, Any]) -> bool: + """判断是否是 MoE 模型 + + 参考 sglang 的模型注册表和架构识别方式 + """ + # 方法1: 检查架构名称 + architectures = config.get("architectures", []) + if any(arch in MOE_ARCHITECTURES for arch in architectures): + return True + + # 方法2: 检查是否有 MoE 相关字段(Mistral 格式) + if config.get("moe"): + return True + + # 方法3: 检查是否有 num_experts 或其变体字段 + # 需要检查 text_config(对于某些多模态模型) + text_config = config.get("text_config", config) + + # 检查各种专家数量字段 + if ( + text_config.get("num_experts") or text_config.get("num_local_experts") or text_config.get("n_routed_experts") + ): # Kimi-K2 使用这个字段 + return True + + return False + 
+ +def _extract_moe_params(config: Dict[str, Any]) -> Dict[str, Any]: + """从 config 中提取 MoE 参数 + + 参考 sglang 的各种 MoE 模型实现 + """ + # 处理嵌套的 text_config + text_config = config.get("text_config", config) + + # 提取基本参数 + result = { + "architectures": config.get("architectures", []), + "model_type": config.get("model_type", "unknown"), + } + + # 专家数量(不同模型字段名不同) + num_experts = ( + text_config.get("num_experts") # Qwen2/3 MoE, DeepSeek V2 + or text_config.get("num_local_experts") # Mixtral + or text_config.get("n_routed_experts") # Kimi-K2, DeepSeek V3 + or config.get("moe", {}).get("num_experts") # Mistral 格式 + ) + + # 每个 token 激活的专家数 + num_experts_per_tok = ( + text_config.get("num_experts_per_tok") + or text_config.get("num_experts_per_token") + or config.get("moe", {}).get("num_experts_per_tok") + or 2 # 默认值 + ) + + # 层数 + num_hidden_layers = text_config.get("num_hidden_layers") or text_config.get("n_layer") or 0 + + # 隐藏层维度 + hidden_size = text_config.get("hidden_size") or text_config.get("d_model") or 0 + + # MoE 专家中间层大小 + moe_intermediate_size = ( + text_config.get("moe_intermediate_size") + or text_config.get("intermediate_size") # 如果没有特殊的 moe_intermediate_size + or 0 + ) + + # 共享专家中间层大小(Qwen2/3 MoE) + shared_expert_intermediate_size = text_config.get("shared_expert_intermediate_size", 0) + + result.update( + { + "num_experts": num_experts or 0, + "num_experts_per_tok": num_experts_per_tok, + "num_hidden_layers": num_hidden_layers, + "hidden_size": hidden_size, + "moe_intermediate_size": moe_intermediate_size, + "shared_expert_intermediate_size": shared_expert_intermediate_size, + } + ) + + # 提取其他有用的参数 + result["num_attention_heads"] = text_config.get("num_attention_heads", 0) + result["num_key_value_heads"] = text_config.get("num_key_value_heads", 0) + result["vocab_size"] = text_config.get("vocab_size", 0) + result["max_position_embeddings"] = text_config.get("max_position_embeddings", 0) + + return result + + +def _estimate_model_size(model_path: Path) -> float: 
+ """估算模型总大小(GB) + + 快速统计 safetensors 文件总大小 + """ + try: + total_size = 0 + for file_path in model_path.glob("*.safetensors"): + total_size += file_path.stat().st_size + return total_size / (1024**3) + except Exception: + return 0.0 + + +def analyze_moe_model(model_path, use_cache=True): + """ + 快速分析 MoE 模型 - 只读取 config.json + + 参数: + model_path: 模型路径(字符串或Path对象) + use_cache: 是否使用缓存(默认True) + + 返回: + dict: { + 'is_moe': 是否是 MoE 模型, + 'num_experts': 专家总数, + 'num_experts_per_tok': 每个 token 激活的专家数, + 'num_hidden_layers': 层数, + 'hidden_size': 隐藏层维度, + 'moe_intermediate_size': MoE 专家中间层大小, + 'shared_expert_intermediate_size': 共享专家中间层大小, + 'architectures': 模型架构列表, + 'model_type': 模型类型, + 'total_size_gb': 模型总大小(估算,GB), + 'cached': 是否从缓存读取 + } + 如果不是 MoE 模型或失败,返回 None + """ + model_path = Path(model_path) + + if not model_path.exists(): + return None + + # 尝试加载缓存 + if use_cache: + cached_result = _load_cache(model_path) + if cached_result: + cached_result["cached"] = True + return cached_result + + # 读取 config.json + config = _load_config_json(model_path) + if not config: + return None + + # 判断是否是 MoE 模型 + if not _is_moe_model(config): + return None + + # 提取 MoE 参数 + params = _extract_moe_params(config) + + # 验证必要参数 + if params["num_experts"] == 0: + return None + + # 估算模型大小 + total_size_gb = _estimate_model_size(model_path) + + # 组装结果 + result = { + "is_moe": True, + "num_experts": params["num_experts"], + "num_experts_per_tok": params["num_experts_per_tok"], + "num_hidden_layers": params["num_hidden_layers"], + "hidden_size": params["hidden_size"], + "moe_intermediate_size": params["moe_intermediate_size"], + "shared_expert_intermediate_size": params["shared_expert_intermediate_size"], + "architectures": params["architectures"], + "model_type": params["model_type"], + "total_size_gb": total_size_gb, + "cached": False, + # 额外参数 + "num_attention_heads": params.get("num_attention_heads", 0), + "num_key_value_heads": params.get("num_key_value_heads", 0), + "vocab_size": 
params.get("vocab_size", 0), + } + + # 保存缓存 + if use_cache: + _save_cache(model_path, result) + + return result + + +def print_analysis(model_path): + """打印模型分析结果""" + print(f"分析模型: {model_path}\n") + + result = analyze_moe_model(model_path) + + if result is None: + print("不是 MoE 模型或分析失败") + return + + print("=" * 70) + print("MoE 模型分析结果") + if result.get("cached"): + print("[使用缓存]") + print("=" * 70) + print(f"模型架构:") + print(f" - 架构: {', '.join(result['architectures'])}") + print(f" - 类型: {result['model_type']}") + print() + print(f"MoE 结构:") + print(f" - 专家总数: {result['num_experts']}") + print(f" - 激活专家数: {result['num_experts_per_tok']} experts/token") + print(f" - 层数: {result['num_hidden_layers']}") + print(f" - 隐藏维度: {result['hidden_size']}") + print(f" - MoE 中间层: {result['moe_intermediate_size']}") + if result["shared_expert_intermediate_size"] > 0: + print(f" - 共享专家中间层: {result['shared_expert_intermediate_size']}") + print() + print(f"大小统计:") + print(f" - 模型总大小: {result['total_size_gb']:.2f} GB") + print("=" * 70) + print() + + +def main(): + import sys + + models = ["/mnt/data2/models/Qwen3-30B-A3B", "/mnt/data2/models/Qwen3-235B-A22B-Instruct-2507"] + + if len(sys.argv) > 1: + models = [sys.argv[1]] + + for model_path in models: + print_analysis(model_path) + + +if __name__ == "__main__": + main() diff --git a/kt-kernel/python/cli/utils/debug_configs.py b/kt-kernel/python/cli/utils/debug_configs.py new file mode 100644 index 0000000..3f30b9a --- /dev/null +++ b/kt-kernel/python/cli/utils/debug_configs.py @@ -0,0 +1,118 @@ +""" +Debug utility to inspect saved run configurations. 
def main():
    """Show all saved run configurations and map their model IDs to registry names.

    Reads ``~/.ktransformers/run_configs.yaml``, prints one table per model ID
    and then tries to resolve every model ID to a name via ``UserModelRegistry``.

    The YAML file can be edited by hand, so unexpected shapes (non-mapping top
    level, ``null`` lists, non-string values) are reported or rendered
    defensively instead of crashing this debug tool.
    """
    config_file = Path.home() / ".ktransformers" / "run_configs.yaml"

    console.print()
    console.print(f"[bold]Configuration file:[/bold] {config_file}")
    console.print()

    if not config_file.exists():
        console.print("[red]✗ Configuration file does not exist![/red]")
        console.print()
        console.print("No configurations have been saved yet.")
        return

    try:
        with open(config_file, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
    except Exception as e:
        console.print(f"[red]✗ Failed to load configuration file: {e}[/red]")
        return

    # safe_load may legally return a list or scalar; .get()/.items() below need mappings.
    if not isinstance(data, dict) or not isinstance(data.get("configs") or {}, dict):
        console.print(
            "[red]✗ Failed to load configuration file: "
            "unexpected structure (expected a 'configs' mapping)[/red]"
        )
        return

    console.print("[green]✓[/green] Configuration file loaded")
    console.print()

    configs = data.get("configs") or {}

    if not configs:
        console.print("[yellow]No saved configurations found.[/yellow]")
        return

    console.print(f"[bold]Found configurations for {len(configs)} model(s):[/bold]")
    console.print()

    for model_id, model_configs in configs.items():
        # A `null` or otherwise non-list value is shown as "0 configuration(s)".
        entries = model_configs if isinstance(model_configs, list) else []

        console.print(f"[cyan]Model ID:[/cyan] {model_id}")
        console.print(f"[dim] {len(entries)} configuration(s)[/dim]")
        console.print()

        if not entries:
            continue

        # Display configs in a table
        table = Table(box=box.ROUNDED, show_header=True, header_style="bold cyan")
        table.add_column("#", justify="right", style="cyan")
        table.add_column("Name", style="white")
        table.add_column("Method", style="yellow")
        table.add_column("TP", justify="right", style="green")
        table.add_column("GPU Experts", justify="right", style="magenta")
        table.add_column("Created", style="dim")

        for i, cfg in enumerate(entries, 1):
            if not isinstance(cfg, dict):
                cfg = {}  # render a placeholder row rather than abort the listing

            raw_method = cfg.get("inference_method", "?")
            method = "?" if raw_method is None else str(raw_method)
            method_display = method.upper()
            if method == "raw":
                method_display += f" ({cfg.get('raw_method', '?')})"
            elif method == "amx":
                method_display += f" ({cfg.get('kt_method', '?')})"

            # YAML may parse timestamps into datetime objects; str() keeps the slice valid.
            created = cfg.get("created_at")
            created_display = str(created)[:19] if created else "Unknown"

            table.add_row(
                str(i),
                str(cfg.get("config_name") or f"Config {i}"),
                method_display,
                str(cfg.get("tp_size", "?")),
                str(cfg.get("gpu_experts", "?")),
                created_display,
            )

        console.print(table)
        console.print()

    # Also check user_models.yaml to show model names
    console.print("[bold]Checking model registry...[/bold]")
    console.print()

    try:
        # Imported inside the try so a missing/broken registry module only yields the warning below.
        from kt_kernel.cli.utils.user_model_registry import UserModelRegistry

        registry = UserModelRegistry()
        all_models = registry.list_models()

        console.print(f"[green]✓[/green] Found {len(all_models)} registered model(s)")
        console.print()

        # Map model IDs to names
        id_to_name = {m.id: m.name for m in all_models}

        console.print("[bold]Model ID → Name mapping:[/bold]")
        console.print()

        for model_id in configs:
            model_name = id_to_name.get(model_id, "[red]Unknown (model not found in registry)[/red]")
            # YAML keys are not guaranteed to be strings; coerce before slicing.
            console.print(f" {str(model_id)[:8]}... → {model_name}")

        console.print()

    except Exception as e:
        console.print(f"[yellow]⚠ Could not load model registry: {e}[/yellow]")
        console.print()
def list_remote_files_ms(repo_id: str) -> List[Dict[str, any]]:
    """
    List files in a ModelScope repository.

    Directory entries of the recursive listing are skipped and every file is
    reported with its repo-relative path, so nested files keep their folder
    prefix and can be matched by path patterns.

    Returns:
        List of dicts with keys: 'path', 'size' (in bytes)
    """
    # NOTE(review): `any` in the annotations of this module is the builtin function;
    # `typing.Any` was presumably intended (left as-is to keep signatures unchanged).
    from modelscope.hub.api import HubApi

    api = HubApi()
    files_info = api.get_model_files(model_id=repo_id, recursive=True)

    result = []
    for file_info in files_info:
        # NOTE(review): relies on the ModelScope file-entry schema (Name / Path / Type / Size);
        # confirm against the installed modelscope version.
        if file_info.get("Type") == "tree":
            continue  # directory entry, not a downloadable file

        # "Path" is repo-relative while "Name" is only the basename, so prefer "Path".
        file_path = file_info.get("Path") or file_info.get("Name", "")
        file_size = file_info.get("Size") or 0

        result.append({"path": file_path, "size": file_size})

    return result


def filter_files_by_pattern(files: List[Dict[str, any]], pattern: str) -> List[Dict[str, any]]:
    """Filter files by glob pattern.

    A file is kept when the pattern matches either its basename or its full
    repo-relative path. The pattern "*" returns the input list itself.
    """
    if pattern == "*":
        return files

    return [
        file
        for file in files
        if fnmatch.fnmatch(Path(file["path"]).name, pattern) or fnmatch.fnmatch(file["path"], pattern)
    ]


def calculate_total_size(files: List[Dict[str, any]]) -> int:
    """Calculate total size of files in bytes.

    Entries whose size is missing or None count as 0 bytes.
    """
    return sum(f.get("size") or 0 for f in files)


def format_file_list_table(files: List[Dict[str, any]], max_display: int = 10):
    """Format file list as a table for display.

    Shows at most ``max_display`` files, followed by a summary row when more
    files exist.
    """
    from rich.table import Table
    from kt_kernel.cli.utils.model_scanner import format_size

    table = Table(show_header=True, header_style="bold")
    table.add_column("File", style="cyan", overflow="fold")
    table.add_column("Size", justify="right")

    # Show first max_display files
    for file in files[:max_display]:
        # Unknown sizes are shown as 0 instead of breaking the size formatter.
        table.add_row(file["path"], format_size(file.get("size") or 0))

    if len(files) > max_display:
        table.add_row(f"... and {len(files) - max_display} more files", "[dim]...[/dim]")

    return table
def prompt_int_with_retry(
    message: str,
    default: Optional[int] = None,
    min_val: Optional[int] = None,
    max_val: Optional[int] = None,
    validator: Optional[Callable[[int], bool]] = None,
    validator_error_msg: Optional[str] = None,
) -> int:
    """Prompt for integer input with validation and retry.

    Keeps asking until the input parses as an integer, lies within the optional
    range and passes the optional custom validator.

    Args:
        message: Prompt message
        default: Default value (optional)
        min_val: Minimum allowed value (optional)
        max_val: Maximum allowed value (optional)
        validator: Custom validation function (optional)
        validator_error_msg: Error message for custom validator (optional)

    Returns:
        Validated integer value
    """
    # Neither value changes between attempts, so build them once.
    prompt_text = f"{message} [{default}]" if default is not None else message
    default_text = str(default) if default is not None else None

    while True:
        user_input = Prompt.ask(prompt_text, default=default_text)

        # With no default, an empty answer makes Prompt.ask return None, and int(None)
        # raises TypeError rather than ValueError; treat both as invalid input.
        try:
            value = int(user_input)
        except (TypeError, ValueError):
            console.print("[red]✗ Invalid input. Please enter a valid integer.[/red]")
            console.print()
            continue

        # Validate range
        if min_val is not None and value < min_val:
            console.print(f"[red]✗ Value must be at least {min_val}[/red]")
            console.print()
            continue

        if max_val is not None and value > max_val:
            console.print(f"[red]✗ Value must be at most {max_val}[/red]")
            console.print()
            continue

        # Custom validation
        if validator is not None and not validator(value):
            error_msg = validator_error_msg or "Invalid value"
            console.print(f"[red]✗ {error_msg}[/red]")
            console.print()
            continue

        # All validations passed
        return value
def prompt_int_list_with_retry(
    message: str,
    default: Optional[str] = None,
    min_val: Optional[int] = None,
    max_val: Optional[int] = None,
    validator: Optional[Callable[[List[int]], tuple[bool, Optional[str]]]] = None,
) -> List[int]:
    """Prompt for comma-separated integer list with validation and retry.

    Separators typed with a Chinese input method (fullwidth comma, ideographic
    comma, ideographic space) are accepted in addition to ASCII commas and spaces.

    Args:
        message: Prompt message
        default: Default value as string (e.g., "0,1,2,3")
        min_val: Minimum allowed value for each integer (optional)
        max_val: Maximum allowed value for each integer (optional)
        validator: Custom validation function that returns (is_valid, error_message) (optional)

    Returns:
        List of validated integers (empty when the input contains no numbers)
    """
    # Written as escapes so the table survives tooling that normalizes punctuation:
    # U+FF0C fullwidth comma, U+3001 ideographic comma, U+3000 ideographic space.
    separators = str.maketrans({"\uff0c": ",", "\u3001": ",", "\u3000": None, " ": None})

    while True:
        # Get input; Prompt.ask returns the default (possibly None) on an empty answer.
        user_input = Prompt.ask(message, default=default)

        # Clean input: support Chinese comma and spaces
        user_input_cleaned = (user_input or "").translate(separators)

        # Try to parse as integers
        try:
            values = [int(x.strip()) for x in user_input_cleaned.split(",") if x.strip()]
        except ValueError:
            console.print("[red]✗ Invalid format. Please enter numbers separated by commas.[/red]")
            console.print()
            continue

        # Validate each value's range
        invalid_values = [
            value
            for value in values
            if (min_val is not None and value < min_val) or (max_val is not None and value > max_val)
        ]

        if invalid_values:
            if min_val is not None and max_val is not None:
                console.print(f"[red]✗ Invalid value(s): {invalid_values}[/red]")
                console.print(f"[yellow]Valid range: {min_val}-{max_val}[/yellow]")
            elif min_val is not None:
                console.print(f"[red]✗ Value(s) must be at least {min_val}: {invalid_values}[/red]")
            elif max_val is not None:
                console.print(f"[red]✗ Value(s) must be at most {max_val}: {invalid_values}[/red]")
            console.print()
            continue

        # Custom validation
        if validator is not None:
            is_valid, error_msg = validator(values)
            if not is_valid:
                console.print(f"[red]✗ {error_msg}[/red]")
                console.print()
                continue

        # All validations passed
        return values
def get_dtype_bytes(dtype_str: str) -> int:
    """Get the number of bytes for a given dtype string.

    Lookup is case-insensitive and ignores a leading ``torch.`` prefix. Besides
    the torch-style names, the short spellings used by serving CLIs
    (e.g. ``fp8_e5m2``, ``fp8_e4m3``, ``bf16``) are recognized.

    Args:
        dtype_str: Dtype name such as "auto", "bfloat16" or "fp8_e5m2".

    Returns:
        Element size in bytes; 2 for "auto" and for any unrecognized name.
    """
    dtype_map = {
        "float32": 4,
        "float": 4,
        "fp32": 4,
        "float16": 2,
        "half": 2,
        "fp16": 2,
        "bfloat16": 2,
        "bf16": 2,
        "float8_e4m3fn": 1,
        "float8_e5m2": 1,
        "fp8_e4m3": 1,
        "fp8_e5m2": 1,
        "auto": 2,  # Usually defaults to bfloat16
    }

    key = str(dtype_str).strip().lower()
    prefix = "torch."
    if key.startswith(prefix):
        key = key[len(prefix):]

    # Unknown names keep the historical fallback of 2 bytes.
    return dtype_map.get(key, 2)
+ verbose: Whether to print detailed information + + Returns: + dict: Dictionary containing calculation details + """ + # Load model config + model_config = ModelConfig(model_path, dtype=dtype) + hf_config = model_config.hf_config + + # Determine dtype bytes + dtype_bytes = get_dtype_bytes(dtype) + if dtype == "auto": + # Auto dtype usually becomes bfloat16 + dtype_bytes = 2 + + # Number of layers + num_layers = model_config.num_attention_layers + + # Check if it's MLA (Multi-head Latent Attention) model + is_mla = hasattr(model_config, "attention_arch") and model_config.attention_arch.name == "MLA" + + result = { + "model_path": model_path, + "max_total_tokens": max_total_tokens, + "tp": tp, + "dtype": dtype, + "dtype_bytes": dtype_bytes, + "num_layers": num_layers, + "is_mla": is_mla, + } + + if is_mla: + # MLA models (DeepSeek-V2/V3, MiniCPM3, etc.) + kv_lora_rank = model_config.kv_lora_rank + qk_rope_head_dim = model_config.qk_rope_head_dim + + # Calculate cell size (per token) + cell_size = (kv_lora_rank + qk_rope_head_dim) * num_layers * dtype_bytes + + result.update( + { + "kv_lora_rank": kv_lora_rank, + "qk_rope_head_dim": qk_rope_head_dim, + "cell_size_bytes": cell_size, + } + ) + + # Check if it's NSA (Native Sparse Attention) model + if is_deepseek_nsa(hf_config): + index_head_dim = get_nsa_index_head_dim(hf_config) + indexer_size_per_token = index_head_dim + index_head_dim // NSATokenToKVPool.quant_block_size * 4 + indexer_dtype_bytes = torch._utils._element_size(NSATokenToKVPool.index_k_with_scale_buffer_dtype) + indexer_cell_size = indexer_size_per_token * num_layers * indexer_dtype_bytes + cell_size += indexer_cell_size + + result.update( + { + "is_nsa": True, + "index_head_dim": index_head_dim, + "indexer_cell_size_bytes": indexer_cell_size, + "total_cell_size_bytes": cell_size, + } + ) + else: + result["is_nsa"] = False + else: + # Standard MHA models + num_kv_heads = model_config.get_num_kv_heads(tp) + head_dim = model_config.head_dim + v_head_dim 
= model_config.v_head_dim + + # Calculate cell size (per token) + cell_size = num_kv_heads * (head_dim + v_head_dim) * num_layers * dtype_bytes + + result.update( + { + "num_kv_heads": num_kv_heads, + "head_dim": head_dim, + "v_head_dim": v_head_dim, + "cell_size_bytes": cell_size, + } + ) + + # Calculate total KV cache size + total_size_bytes = max_total_tokens * cell_size + total_size_gb = total_size_bytes / (1024**3) + + # For MHA models with separate K and V buffers + if not is_mla: + k_size_bytes = max_total_tokens * num_kv_heads * head_dim * num_layers * dtype_bytes + v_size_bytes = max_total_tokens * num_kv_heads * v_head_dim * num_layers * dtype_bytes + k_size_gb = k_size_bytes / (1024**3) + v_size_gb = v_size_bytes / (1024**3) + + result.update( + { + "k_size_gb": k_size_gb, + "v_size_gb": v_size_gb, + } + ) + + result.update( + { + "total_size_bytes": total_size_bytes, + "total_size_gb": total_size_gb, + } + ) + + if verbose: + print(f"Model: {model_path}") + print(f"Tokens: {max_total_tokens}, TP: {tp}, Dtype: {dtype}") + print(f"Architecture: {'MLA' if is_mla else 'MHA'}") + print(f"Layers: {num_layers}") + + if is_mla: + print(f"KV LoRA Rank: {kv_lora_rank}, QK RoPE Head Dim: {qk_rope_head_dim}") + if result.get("is_nsa"): + print(f"NSA Index Head Dim: {index_head_dim}") + print( + f"Cell size: {cell_size} bytes (Main: {result['cell_size_bytes']}, Indexer: {result['indexer_cell_size_bytes']})" + ) + else: + print(f"Cell size: {cell_size} bytes") + else: + print(f"KV Heads: {num_kv_heads}, Head Dim: {head_dim}, V Head Dim: {v_head_dim}") + print(f"Cell size: {cell_size} bytes") + print(f"K size: {k_size_gb:.2f} GB, V size: {v_size_gb:.2f} GB") + + print(f"Total KV Cache Size: {total_size_gb:.2f} GB") + + return result + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="Calculate KV cache size for a model") + parser.add_argument("model_path", help="Path to the model") + parser.add_argument("max_total_tokens", type=int, 
def discover_and_register_global(
    min_size_gb: float = 2.0, max_depth: int = 6, show_progress: bool = True, lang: str = "en"
) -> Tuple[int, int, List[UserModel]]:
    """
    Run a system-wide model scan and register every model not yet in the registry.

    Args:
        min_size_gb: Minimum model size in GB
        max_depth: Maximum search depth
        show_progress: Whether to print a "scanning" notice before the scan starts
        lang: Language for messages ("zh" selects Chinese, anything else English)

    Returns:
        Tuple of (total_found, new_found, registered_models)
    """
    registry = UserModelRegistry()

    if show_progress:
        notice = (
            "[dim]正在扫描系统中的模型权重,这可能需要30-60秒...[/dim]"
            if lang == "zh"
            else "[dim]Scanning system for model weights, this may take 30-60 seconds...[/dim]"
        )
        console.print(notice)

    # Global scan (mount_points=None)
    discovered = discover_models(mount_points=None, min_size_gb=min_size_gb, max_depth=max_depth)

    # Keep only models whose path the registry does not know yet; the filter runs
    # to completion before anything is registered.
    unseen = [candidate for candidate in discovered if not registry.find_by_path(candidate.path)]

    registered = []
    for candidate in unseen:
        entry = _create_and_register_model(registry, candidate)
        if entry:
            registered.append(entry)

    return len(discovered), len(unseen), registered
def _create_and_register_model(registry: UserModelRegistry, scanned_model: ScannedModel) -> Optional[UserModel]:
    """
    Build a UserModel for a scanned model directory and add it to the registry.

    The model name comes from ``registry.suggest_name`` applied to the folder name.
    Repository detection and MoE analysis are best-effort: any exception they
    raise is swallowed and the corresponding fields are simply left untouched.

    Args:
        registry: UserModelRegistry instance
        scanned_model: ScannedModel to register

    Returns:
        Registered UserModel, or None if adding it to the registry failed
    """
    model_name = registry.suggest_name(scanned_model.folder_name)
    user_model = UserModel(name=model_name, path=scanned_model.path, format=scanned_model.format)

    # Best-effort repo detection
    try:
        from kt_kernel.cli.utils.repo_detector import detect_repo_for_model

        detected = detect_repo_for_model(scanned_model.path)
        if detected:
            user_model.repo_id, user_model.repo_type = detected
    except Exception:
        # Detection is optional; continue without repo information.
        pass

    # Best-effort MoE detection, only attempted for safetensors models
    if scanned_model.format == "safetensors":
        try:
            from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model

            analysis = analyze_moe_model(scanned_model.path, use_cache=True)
            is_moe = bool(analysis and analysis.get("is_moe"))
            user_model.is_moe = is_moe
            if is_moe:
                user_model.moe_num_experts = analysis.get("num_experts")
                user_model.moe_num_experts_per_tok = analysis.get("num_experts_per_tok")
        except Exception:
            # MoE fields keep whatever values they had when analysis fails.
            pass

    try:
        registry.add_model(user_model)
    except Exception:
        # Registration failures are reported to the caller as None.
        return None
    return user_model
+ + Args: + total_found: Total models found + new_found: New models found + registered: List of registered UserModel objects + lang: Language ("en" or "zh") + show_models: Whether to show model list + max_show: Maximum models to show + """ + console.print() + + if new_found == 0: + if total_found > 0: + if lang == "zh": + console.print(f"[green]✓[/green] 扫描完成:找到 {total_found} 个模型,所有模型均已在列表中") + else: + console.print(f"[green]✓[/green] Scan complete: found {total_found} models, all already in the list") + else: + if lang == "zh": + console.print("[yellow]未找到模型[/yellow]") + else: + console.print("[yellow]No models found[/yellow]") + return + + # Show summary + if lang == "zh": + console.print(f"[green]✓[/green] 扫描完成:找到 {total_found} 个模型,其中 {new_found} 个为新模型") + else: + console.print(f"[green]✓[/green] Scan complete: found {total_found} models, {new_found} are new") + + # Show registered count + if len(registered) > 0: + if lang == "zh": + console.print(f"[green]✓[/green] 成功添加 {len(registered)} 个新模型到列表") + else: + console.print(f"[green]✓[/green] Successfully added {len(registered)} new models to list") + + # Show model list + if show_models and registered: + console.print() + if lang == "zh": + console.print(f"[dim]新发现的模型(前{max_show}个):[/dim]") + else: + console.print(f"[dim]Newly discovered models (first {max_show}):[/dim]") + + for i, model in enumerate(registered[:max_show], 1): + # Get size from registry or estimate + size_str = "?.? GB" + # Try to find the ScannedModel to get size + # For now just show name and path + console.print(f" {i}. {model.name} ({model.format})") + console.print(f" [dim]{model.path}[/dim]") + + if len(registered) > max_show: + remaining = len(registered) - max_show + if lang == "zh": + console.print(f" [dim]... 还有 {remaining} 个新模型[/dim]") + else: + console.print(f" [dim]... 
@dataclass
class ScannedModel:
    """Temporary structure for scanned model information"""

    path: str  # Absolute path to model directory
    format: str  # "safetensors" | "gguf" | "mixed"
    size_bytes: int  # Total size in bytes
    file_count: int  # Number of model files
    files: List[str]  # List of model file names

    @property
    def size_gb(self) -> float:
        """Get size in GB"""
        return self.size_bytes / (1024**3)

    @property
    def folder_name(self) -> str:
        """Get the folder name (default model name)"""
        return Path(self.path).name


class ModelScanner:
    """Scanner for discovering models in directory trees"""

    def __init__(self, min_size_gb: float = 10.0):
        """
        Initialize scanner

        Args:
            min_size_gb: Minimum folder size in GB to be considered a model
        """
        self.min_size_bytes = int(min_size_gb * 1024**3)

    def scan_directory(
        self, base_path: Path, exclude_paths: Optional[Set[str]] = None
    ) -> Tuple[List[ScannedModel], List[str]]:
        """
        Scan directory tree for models

        Traversal and the per-model ``files`` lists are sorted, so results do not
        depend on filesystem enumeration order.

        Args:
            base_path: Root directory to scan
            exclude_paths: Set of absolute paths to exclude from results
                (matched both as given and in resolved form)

        Returns:
            Tuple of (valid_models, warnings)
            - valid_models: List of ScannedModel instances
            - warnings: List of warning messages

        Raises:
            ValueError: If base_path does not exist or is not a directory
        """
        if not base_path.exists():
            raise ValueError(f"Path does not exist: {base_path}")

        if not base_path.is_dir():
            raise ValueError(f"Path is not a directory: {base_path}")

        # Visited directories are compared in resolved form, so resolve the exclusions
        # too, while keeping the raw strings for backward compatibility.
        excluded: Set[str] = set()
        for raw in exclude_paths or ():
            excluded.add(raw)
            excluded.add(str(Path(raw).resolve()))

        results: List[ScannedModel] = []
        warnings: List[str] = []

        # Walk the directory tree
        for root, dirs, files in os.walk(base_path):
            dirs.sort()  # deterministic traversal order
            root_path = Path(root).resolve()

            # Skip if already registered
            if str(root_path) in excluded:
                dirs[:] = []  # Don't descend into this directory
                continue

            # Check for model files
            safetensors_files = sorted(f for f in files if f.endswith(".safetensors"))
            gguf_files = sorted(f for f in files if f.endswith(".gguf"))

            if not safetensors_files and not gguf_files:
                continue  # No model files in this directory

            # Calculate total size
            model_files = safetensors_files + gguf_files
            total_size = self._calculate_total_size(root_path, model_files)

            # Check if size meets minimum threshold
            if total_size < self.min_size_bytes:
                continue  # Too small, but keep scanning subdirectories

            if safetensors_files and gguf_files:
                # Mixed format - issue warning
                warnings.append(
                    f"Mixed format detected in {root_path}: "
                    f"{len(safetensors_files)} safetensors + {len(gguf_files)} gguf files. "
                    "Please separate into different folders and re-scan."
                )
                dirs[:] = []  # Don't descend into mixed format directories
                continue

            results.append(
                ScannedModel(
                    path=str(root_path),
                    format="safetensors" if safetensors_files else "gguf",
                    size_bytes=total_size,
                    file_count=len(model_files),
                    files=model_files,
                )
            )

            # Subdirectories are still visited: each one is checked independently
            # against the configured minimum size.

        return results, warnings

    def scan_single_path(self, path: Path) -> Optional[ScannedModel]:
        """
        Scan a single path for model files

        The minimum-size threshold is not applied here.

        Args:
            path: Path to scan

        Returns:
            ScannedModel instance or None if not a valid model

        Raises:
            ValueError: If the directory contains both safetensors and gguf files
        """
        if not path.exists() or not path.is_dir():
            return None

        # Find model files (sorted for a stable `files` order)
        safetensors_files = sorted(path.glob("*.safetensors"))
        gguf_files = sorted(path.glob("*.gguf"))

        if not safetensors_files and not gguf_files:
            return None

        # Check for mixed format
        if safetensors_files and gguf_files:
            raise ValueError(
                f"Mixed format detected: {len(safetensors_files)} safetensors + "
                f"{len(gguf_files)} gguf files. Please use a single format."
            )

        model_files = [f.name for f in safetensors_files + gguf_files]

        return ScannedModel(
            path=str(path.resolve()),
            format="safetensors" if safetensors_files else "gguf",
            size_bytes=self._calculate_total_size(path, model_files),
            file_count=len(model_files),
            files=model_files,
        )

    def _calculate_total_size(self, directory: Path, filenames: List[str]) -> int:
        """
        Calculate total size of specified files in directory

        Args:
            directory: Directory containing the files
            filenames: List of filenames to sum

        Returns:
            Total size in bytes (missing or inaccessible files count as 0)
        """
        total = 0
        for filename in filenames:
            try:
                total += (directory / filename).stat().st_size
            except OSError:
                # Covers both vanished and unreadable files, without a separate exists() race.
                pass
        return total


# Convenience functions


def scan_directory(
    base_path: Path, min_size_gb: float = 10.0, exclude_paths: Optional[Set[str]] = None
) -> Tuple[List[ScannedModel], List[str]]:
    """
    Convenience function to scan a directory

    Args:
        base_path: Root directory to scan
        min_size_gb: Minimum folder size in GB
        exclude_paths: Set of paths to exclude

    Returns:
        Tuple of (models, warnings)
    """
    scanner = ModelScanner(min_size_gb=min_size_gb)
    return scanner.scan_directory(base_path, exclude_paths)


def scan_single_path(path: Path) -> Optional[ScannedModel]:
    """
    Convenience function to scan a single path

    Args:
        path: Path to scan

    Returns:
        ScannedModel or None
    """
    scanner = ModelScanner()
    return scanner.scan_single_path(path)
def find_files_fast(mount_point: str, pattern: str, max_depth: int = 6, timeout: int = 30) -> List[str]:
    """
    Use the find command to quickly locate files

    The command is executed without a shell: every argument is passed to
    ``find`` verbatim, so directory names or patterns containing quotes,
    ``$`` or backticks can neither break the command nor inject another one.

    Args:
        mount_point: Starting directory
        pattern: File name pattern for ``-name`` (e.g., "config.json", "*.gguf")
        max_depth: Maximum directory depth (default: 6)
        timeout: Command timeout in seconds

    Returns:
        List of matching file paths as printed by find. Empty when find is
        not available, cannot be started, or times out.
    """
    start = str(mount_point)
    if start.startswith("-"):
        # Keep find from parsing a relative path such as "-models" as an option
        start = "./" + start

    command = ["find", start, "-maxdepth", str(int(max_depth)), "-name", pattern, "-type", "f"]

    try:
        result = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,  # permission errors are expected and ignored
            text=True,
            errors="surrogateescape",  # tolerate file names that are not valid UTF-8
            timeout=timeout,
            check=False,
        )
    except (subprocess.TimeoutExpired, OSError):
        return []

    # The return code is deliberately ignored: find exits non-zero when it
    # hits unreadable directories but still prints everything it found.
    # Lines are not stripped so names with leading/trailing blanks survive.
    return [line for line in result.stdout.split("\n") if line.strip()]
def scan_all_models_fast(mount_points: List[str], min_size_gb: float = 10.0, max_depth: int = 6) -> List[str]:
    """
    Collect every valid model directory below the given mount points

    Each existing mount point is searched twice with find: once for
    config.json and once for *.gguf. The parent directory of every hit is
    validated and, if accepted, recorded by its resolved path.

    Args:
        mount_points: List of mount points to scan
        min_size_gb: Minimum model size in GB
        max_depth: Maximum search depth (default: 6)

    Returns:
        Sorted list of unique, resolved model directory paths
    """
    discovered = set()
    marker_patterns = ("config.json", "*.gguf")

    for mount in mount_points:
        if not os.path.exists(mount):
            continue

        for marker in marker_patterns:
            for hit in find_files_fast(mount, marker, max_depth=max_depth):
                candidate_dir = Path(hit).parent
                accepted, _ = is_valid_model_directory(candidate_dir, min_size_gb)
                if accepted:
                    discovered.add(str(candidate_dir.resolve()))

    return sorted(discovered)
def scan_directory_for_models(directory: str, min_file_size_gb: float = 2.0) -> Dict[str, tuple]:
    """
    Scan a directory for models using the find command with a size filter

    Only weight files strictly larger than the threshold are considered.
    The threshold is handed to find in bytes ("-size +Nc") so fractional
    values such as 2.5 are honoured exactly instead of being truncated to a
    whole number of GiB/MiB. find is executed without a shell, so unusual
    characters in ``directory`` cannot alter the command.

    A directory with large *.gguf files is reported as "gguf". A directory
    with large *.safetensors files is reported as "safetensors" only when
    it also contains config.json; such an entry replaces a "gguf" entry for
    the same directory.

    Args:
        directory: Directory to scan
        min_file_size_gb: Individual file size threshold in GB (default: 2.0)

    Returns:
        Dict mapping model_path -> (model_type, size_bytes, file_count, files)

    Raises:
        subprocess.TimeoutExpired: If one find invocation exceeds 120 seconds
    """
    min_bytes = max(int(min_file_size_gb * 1024**3), 0)

    start = str(directory)
    if start.startswith("-"):
        # Keep find from parsing a relative path such as "-models" as an option
        start = "./" + start

    def _large_files_by_dir(pattern: str) -> Dict[str, list]:
        """Run find for one pattern and group (name, size) pairs by parent directory."""
        command = [
            "find",
            start,
            "-name",
            pattern,
            "-type",
            "f",
            "-size",
            f"+{min_bytes}c",
            "-printf",
            "%p\\t%s\\n",  # find itself expands the \t and \n escapes
        ]
        try:
            completed = subprocess.run(
                command,
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,  # permission errors are expected and ignored
                text=True,
                errors="surrogateescape",
                timeout=120,
                check=False,
            )
        except OSError:
            # find is missing or cannot be executed: nothing to report
            return {}

        grouped = defaultdict(list)
        for line in completed.stdout.split("\n"):
            # Split on the LAST tab so a tab inside the path is harmless
            path_text, separator, size_text = line.rpartition("\t")
            if not separator or not size_text.isdigit():
                continue
            file_path = Path(path_text)
            grouped[str(file_path.parent)].append((file_path.name, int(size_text)))
        return grouped

    def _summarize(model_type: str, files: list) -> tuple:
        """Build the (type, total size, count, names) tuple for one directory."""
        return (model_type, sum(size for _, size in files), len(files), [name for name, _ in files])

    model_info = {}

    # 1. Every directory holding large *.gguf files is a GGUF model
    for dir_path, files in _large_files_by_dir("*.gguf").items():
        model_info[dir_path] = _summarize("gguf", files)

    # 2. Large *.safetensors files only count when config.json sits next to them
    for dir_path, files in _large_files_by_dir("*.safetensors").items():
        if os.path.exists(os.path.join(dir_path, "config.json")):
            model_info[dir_path] = _summarize("safetensors", files)

    return model_info
def find_model_roots_from_paths(model_paths: List[str]) -> Tuple[List[str], Dict[str, int]]:
    """
    Find optimal root paths from model paths using a tree-based algorithm

    Algorithm:
    1. Normalise the model paths and count, for every path prefix x,
       f(x) = number of models located at or below x
    2. A non-model directory x is a candidate root when f(parent(x)) == f(x),
       i.e. walking one level up does not bring any additional model into view
    3. Among candidates on the same chain with the same f value keep only the
       deepest one, and drop candidates with fewer than 3 path components
       (e.g. "/" or "/data")

    Args:
        model_paths: List of model directory paths. Trailing slashes and
            duplicates are tolerated.

    Returns:
        (root_paths, subtree_sizes) where:
        - root_paths: Sorted list of inferred root directories
        - subtree_sizes: Dict mapping each root to the number of models below it
    """
    if not model_paths:
        return [], {}

    # Normalise once so "/data/m/" and "/data/m" denote the same tree node;
    # tree nodes are always built with str(Path(...)), and a raw string with a
    # trailing slash would otherwise never be recognised as a model.
    model_set = {str(Path(model_path)) for model_path in model_paths}

    # 1. f[x] = number of models at or below x (x ranges over all prefixes)
    counts = defaultdict(int)
    for model_path in model_set:
        parts = Path(model_path).parts
        for depth in range(1, len(parts) + 1):
            counts[str(Path(*parts[:depth]))] += 1
    f = dict(counts)  # plain dict: lookups below must not create entries

    # 2. Candidate roots. f(x) >= max(f(children)) always holds because f is a
    #    sum of non-negative terms, so only the parent comparison matters.
    #    The filesystem root is its own parent and therefore always qualifies;
    #    it is removed by the depth filter in step 3.
    candidate_roots = [
        path
        for path, count in f.items()
        if path not in model_set and f.get(str(Path(path).parent)) == count
    ]

    # 3. Remove redundant roots, deepest candidates first
    roots = []
    for root in sorted(candidate_roots, key=lambda p: -len(Path(p).parts)):
        # An already selected descendant with the same model count makes this
        # ancestor redundant.
        is_redundant = any(
            selected.startswith(root + "/") and f[selected] == f[root] for selected in roots
        )
        # Very shallow paths (< 3 components) are too broad to be useful roots
        if not is_redundant and len(Path(root).parts) >= 3:
            roots.append(root)

    subtree_sizes = {root: f[root] for root in roots}

    return sorted(roots), subtree_sizes
def _get_mount_points(mounts_file: str = "/proc/mounts") -> List[str]:
    """
    Get all valid mount points from the mount table, filtering out system paths

    Args:
        mounts_file: Mount table to read, one "device mount_point fs_type ..."
            record per line (default: "/proc/mounts")

    Returns:
        Sorted list of mount point paths suitable for model storage
        (excludes root "/" to avoid scanning the entire filesystem). When the
        table yields nothing, or cannot be read, the existing and readable
        directories among /home, /data and /mnt are returned instead.
    """
    # System directories (and everything below them) that are unlikely to
    # contain model files. Matching is done on whole path components so that
    # e.g. "/usr" does not also exclude "/usrdata".
    excluded_dirs = (
        "/snap",
        "/proc",
        "/sys",
        "/run",
        "/boot",
        "/dev",
        "/usr",
        "/lib",
        "/lib64",
        "/bin",
        "/sbin",
        "/etc",
        "/opt",
        "/var",
        "/tmp",
    )

    # Pseudo / virtual filesystems
    pseudo_fs = frozenset(
        (
            "proc",
            "sysfs",
            "devpts",
            "tmpfs",
            "devtmpfs",
            "cgroup",
            "cgroup2",
            "pstore",
            "bpf",
            "tracefs",
            "debugfs",
            "hugetlbfs",
            "mqueue",
            "configfs",
            "securityfs",
            "fuse.gvfsd-fuse",
            "fusectl",
            "squashfs",
            "overlay",  # snap packages
        )
    )

    fallback_dirs = ("/home", "/data", "/mnt")

    def _is_usable(path: str) -> bool:
        """True when the path exists and is readable by the current user."""
        return os.path.exists(path) and os.access(path, os.R_OK)

    def _is_system_path(path: str) -> bool:
        """True when the path is an excluded directory or lies below one."""
        return any(path == base or path.startswith(base + "/") for base in excluded_dirs)

    def _unescape(field: str) -> str:
        """Decode the octal escapes the mount table uses for blank, tab, newline, backslash."""
        # The backslash escape is decoded last so its result is never re-read
        return field.replace("\\040", " ").replace("\\011", "\t").replace("\\012", "\n").replace("\\134", "\\")

    mount_points = set()

    try:
        with open(mounts_file, "r", encoding="utf-8", errors="surrogateescape") as f:
            for line in f:
                parts = line.split()
                if len(parts) < 3:
                    continue

                mount_point, fs_type = _unescape(parts[1]), parts[2]

                if fs_type in pseudo_fs:
                    continue

                # Skip root directory (too large to scan)
                if mount_point == "/":
                    continue

                if _is_system_path(mount_point):
                    continue

                if _is_usable(mount_point):
                    mount_points.add(mount_point)
    except OSError:
        # No readable mount table (e.g. non-Linux system): use the fallback below
        pass

    if not mount_points:
        # Only offer fallback directories that can actually be scanned
        mount_points = {path for path in fallback_dirs if _is_usable(path)}

    return sorted(mount_points)
def build_amx_table(
    models: List,
    status_map: dict = None,  # Kept for API compatibility but not used
    show_index: bool = True,
    start_index: int = 1,
    show_linked_gpus: bool = False,
    gpu_models: Optional[List] = None,
) -> Tuple[Table, List]:
    """
    Build the table listing AMX models.

    There is no SHA256 column in this table; the original author's note is
    that AMX models are quantized locally.

    Args:
        models: List of AMX model objects
        status_map: (Unused - kept for API compatibility)
        show_index: Whether to show # column for selection (default: True)
        start_index: Number shown for the first row
        show_linked_gpus: Whether to add a sub-row naming the linked GPU models
        gpu_models: List of GPU models (required if show_linked_gpus=True)

    Returns:
        Tuple of (Table object, list of models in display order)
    """
    unknown_cell = "[dim]?[/dim]"

    table = Table(show_header=True, header_style="bold", show_lines=False)
    if show_index:
        table.add_column("#", justify="right", style="cyan", no_wrap=True)
    column_specs = (
        ("Name", {"style": "cyan", "no_wrap": True}),
        ("Path", {"style": "dim", "overflow": "fold"}),
        ("Total", {"justify": "right"}),
        ("Method", {"justify": "center", "style": "yellow"}),
        ("NUMA", {"justify": "center", "style": "green"}),
        ("Source", {"style": "dim", "overflow": "fold"}),
    )
    for header, options in column_specs:
        table.add_column(header, **options)

    # AMX model id -> names of the GPU models it is linked to
    linked_gpu_names = {}
    if show_linked_gpus and gpu_models:
        for amx_model in models:
            if not amx_model.gpu_model_ids:
                continue
            names = []
            for wanted_id in amx_model.gpu_model_ids:
                # First GPU model with a matching id wins; unknown ids are skipped
                match = next((gpu for gpu in gpu_models if gpu.id == wanted_id), None)
                if match is not None:
                    names.append(match.name)
            if names:
                linked_gpu_names[amx_model.id] = names

    displayed_models = []

    for position, model in enumerate(models, start_index):
        displayed_models.append(model)

        size_str = format_model_size(Path(model.path), "safetensors")

        # Best-effort read of the quantization metadata stored in config.json
        config_method = None
        config_numa = None
        try:
            config_file = Path(model.path) / "config.json"
            if config_file.exists():
                with open(config_file, "r", encoding="utf-8") as handle:
                    quant_info = json.load(handle).get("amx_quantization", {})
                if quant_info.get("converted"):
                    config_method = quant_info.get("method")
                    config_numa = quant_info.get("numa_count")
        except Exception:
            pass

        # Priority: model fields > config.json > "?"
        if model.amx_quant_method:
            method_display = model.amx_quant_method.upper()
        elif config_method:
            method_display = config_method.upper()
        else:
            method_display = unknown_cell

        if model.amx_numa_nodes:
            numa_display = str(model.amx_numa_nodes)
        elif config_numa:
            numa_display = str(config_numa)
        else:
            numa_display = unknown_cell

        source_display = model.amx_source_model or "[dim]-[/dim]"

        cells = [model.name, model.path, size_str, method_display, numa_display, source_display]
        if show_index:
            cells.insert(0, str(position))
        table.add_row(*cells)

        # Optional sub-row naming the GPU models linked to this AMX model
        if show_linked_gpus and model.id in linked_gpu_names:
            gpu_names_str = ", ".join([f"[dim]{name}[/dim]" for name in linked_gpu_names[model.id]])
            sub_cells = [f" [dim]↳ GPU: {gpu_names_str}[/dim]", "", "", "", "", ""]
            if show_index:
                sub_cells.insert(0, "")
            table.add_row(*sub_cells, style="dim")

    return table, displayed_models
+ + Args: + models: List of GGUF model objects + status_map: SHA256_STATUS_MAP for formatting status + show_index: Whether to show # column for selection (default: True) + start_index: Starting index number + + Returns: + Tuple of (Table object, list of models in display order) + """ + table = Table(show_header=True, header_style="bold", show_lines=False) + + if show_index: + table.add_column("#", justify="right", style="cyan", no_wrap=True) + + table.add_column("Name", style="cyan", no_wrap=True) + table.add_column("Path", style="dim", overflow="fold") + table.add_column("Total", justify="right") + table.add_column("Repository", style="dim", overflow="fold") + table.add_column("SHA256", justify="center") + + displayed_models = [] + + for i, model in enumerate(models, start_index): + displayed_models.append(model) + + # Calculate size + size_str = format_model_size(Path(model.path), "gguf") + + # Repository and SHA256 + repo_str = format_repo_info(model) + sha256_str = format_sha256_status(model, status_map) + + row = [] + if show_index: + row.append(str(i)) + + row.extend([model.name, model.path, size_str, repo_str, sha256_str]) + + table.add_row(*row) + + return table, displayed_models diff --git a/kt-kernel/python/cli/utils/model_verifier.py b/kt-kernel/python/cli/utils/model_verifier.py new file mode 100644 index 0000000..175cd70 --- /dev/null +++ b/kt-kernel/python/cli/utils/model_verifier.py @@ -0,0 +1,918 @@ +""" +Model Verifier + +SHA256 verification for model integrity +""" + +import hashlib +import requests +import os +from pathlib import Path +from typing import Dict, Any, Literal, Tuple +from concurrent.futures import ProcessPoolExecutor, as_completed + + +def _compute_file_sha256(file_path: Path) -> Tuple[str, str, float]: + """ + Compute SHA256 for a single file (worker function for multiprocessing). 
def _git_blob_sha1(file_path: Path) -> str:
    """Return the git blob object id (SHA-1 over "blob <size>\\0" + content) of a file."""
    digest = hashlib.sha1()
    digest.update(b"blob %d\0" % file_path.stat().st_size)
    with open(file_path, "rb") as handle:
        for chunk in iter(lambda: handle.read(1024 * 1024), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_model_integrity(
    repo_type: Literal["huggingface", "modelscope"],
    repo_id: str,
    local_dir: Path,
    progress_callback=None,
) -> Dict[str, Any]:
    """
    Verify local model integrity against the hashes published by the remote repository.

    Verifies all important files:
    - *.safetensors (weights)
    - *.json (config files)
    - *.py (custom model code)

    Remote entries are matched to local files by base name. A 64-character
    reference is compared with the local SHA256. A 40-character reference is
    treated as a git blob id (this is what HuggingFace reports for files not
    stored in LFS) and compared with the git blob SHA-1 of the local file.
    Entries without any reference hash are skipped and not counted.

    Args:
        repo_type: Type of repository ("huggingface" or "modelscope")
        repo_id: Repository ID (e.g., "deepseek-ai/DeepSeek-V3")
        local_dir: Local directory containing model files
        progress_callback: Optional callback function(message: str) for progress updates

    Returns:
        Dictionary with verification results:
        {
            "status": "passed" | "failed" | "error",
            "files_checked": int,
            "files_passed": int,
            "files_failed": [list of filenames],
            "error_message": str (optional),
            "is_network_error": bool (only present when status is "error"
                                      because an exception was raised)
        }
    """

    def report_progress(msg: str):
        """Helper to report progress"""
        if progress_callback:
            progress_callback(msg)

    try:
        # Convert repo_type to platform format
        platform = "hf" if repo_type == "huggingface" else "ms"

        # 1. Fetch official hashes from remote
        report_progress("Fetching official SHA256 hashes from remote repository...")
        official_hashes = fetch_model_sha256(repo_id, platform)
        report_progress(f"✓ Fetched {len(official_hashes)} file hashes from remote")

        if not official_hashes:
            return {
                "status": "error",
                "files_checked": 0,
                "files_passed": 0,
                "files_failed": [],
                "error_message": f"No verifiable files found in remote repository: {repo_id}",
            }

        # 2. Calculate local SHA256 hashes with progress
        report_progress(f"Calculating SHA256 for local files...")

        local_files = []
        for pattern in ["*.safetensors", "*.json", "*.py"]:
            local_files.extend([f for f in local_dir.glob(pattern) if f.is_file()])

        if not local_files:
            return {
                "status": "error",
                "files_checked": 0,
                "files_passed": 0,
                "files_failed": [],
                "error_message": f"No verifiable files found in local directory: {local_dir}",
            }

        local_hashes = calculate_local_sha256(
            local_dir,
            file_pattern="*.safetensors",  # Unused when files_list is provided
            progress_callback=report_progress,
            files_list=local_files,
        )
        report_progress(f"✓ Calculated {len(local_hashes)} local file hashes")

        # Index local hashes by base name once instead of rescanning per remote file
        local_by_basename = {}
        for local_file, local_hash_value in local_hashes.items():
            local_by_basename.setdefault(Path(local_file).name, local_hash_value)

        # 3. Compare hashes with progress
        total = len(official_hashes)
        report_progress(f"Comparing {total} files...")
        files_failed = []
        files_missing = []
        files_passed = 0
        files_skipped = 0

        for idx, (filename, official_hash) in enumerate(official_hashes.items(), 1):
            # Remote names may carry a directory part; local files are matched by base name
            file_basename = Path(filename).name
            local_hash = local_by_basename.get(file_basename)

            if local_hash is None:
                files_missing.append(filename)
                report_progress(f" [{idx}/{total}] ✗ {file_basename} - MISSING")
                continue

            if not official_hash:
                # Nothing to compare against; do not abort the whole verification
                files_skipped += 1
                report_progress(f" [{idx}/{total}] - {file_basename} - SKIPPED (no reference hash)")
                continue

            if len(official_hash) == 40:
                # 40 hex digits cannot be a SHA256: compare git blob ids instead
                local_hash = _git_blob_sha1(local_dir / file_basename)

            if local_hash.lower() != official_hash.lower():
                files_failed.append(f"{filename} (hash mismatch)")
                report_progress(f" [{idx}/{total}] ✗ {file_basename} - HASH MISMATCH")
            else:
                files_passed += 1
                report_progress(f" [{idx}/{total}] ✓ {file_basename}")

        # 4. Return results
        total_checked = total - files_skipped

        if files_failed or files_missing:
            all_failed = files_failed + [f"{f} (missing)" for f in files_missing]
            return {
                "status": "failed",
                "files_checked": total_checked,
                "files_passed": files_passed,
                "files_failed": all_failed,
                "error_message": f"{len(all_failed)} file(s) failed verification",
            }
        else:
            return {
                "status": "passed",
                "files_checked": total_checked,
                "files_passed": files_passed,
                "files_failed": [],
            }

    except ImportError as e:
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": f"Missing required package: {str(e)}. Install with: pip install huggingface-hub modelscope",
            "is_network_error": False,
        }
    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
        requests.exceptions.RequestException,
    ) as e:
        # Network-related errors - suggest mirror
        error_msg = f"Network error: {str(e)}"
        if repo_type == "huggingface":
            error_msg += "\n\nTry using HuggingFace mirror:\n export HF_ENDPOINT=https://hf-mirror.com"
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": error_msg,
            "is_network_error": True,
        }
    except Exception as e:
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": f"Verification failed: {str(e)}",
            "is_network_error": False,
        }
+ + Args: + local_dir: Directory to scan + file_pattern: Glob pattern for files to hash (ignored if files_list is provided) + progress_callback: Optional callback function(message: str) for progress updates + files_list: Optional pre-filtered list of files to hash (overrides file_pattern) + + Returns: + Dictionary mapping filename to SHA256 hash + """ + result = {} + + if not local_dir.exists(): + return result + + # Get all files first to report total + if files_list is not None: + files_to_hash = files_list + else: + files_to_hash = [f for f in local_dir.glob(file_pattern) if f.is_file()] + total_files = len(files_to_hash) + + if total_files == 0: + return result + + # Use min(16, total_files) workers to avoid over-spawning processes + max_workers = min(16, total_files) + + if progress_callback: + progress_callback(f" Using {max_workers} parallel workers for SHA256 calculation") + + # Use ProcessPoolExecutor for CPU-intensive SHA256 computation + completed_count = 0 + with ProcessPoolExecutor(max_workers=max_workers) as executor: + # Submit all tasks + future_to_file = {executor.submit(_compute_file_sha256, file_path): file_path for file_path in files_to_hash} + + # Process results as they complete + for future in as_completed(future_to_file): + completed_count += 1 + try: + filename, sha256_hash, file_size_mb = future.result() + result[filename] = sha256_hash + + if progress_callback: + progress_callback(f" [{completed_count}/{total_files}] ✓ {filename} ({file_size_mb:.1f} MB)") + + except Exception as e: + file_path = future_to_file[future] + if progress_callback: + progress_callback(f" [{completed_count}/{total_files}] ✗ {file_path.name} - Error: {str(e)}") + + return result + + +def fetch_model_sha256( + repo_id: str, + platform: Literal["hf", "ms"], + revision: str | None = None, + use_mirror: bool = False, + timeout: int | None = None, +) -> dict[str, str]: + """ + 获取模型仓库中所有重要文件的 sha256 哈希值。 + + 包括: + - *.safetensors (权重文件) + - *.json (配置文件:config.json, 
tokenizer_config.json 等) + - *.py (自定义模型代码:modeling.py, configuration.py 等) + + Args: + repo_id: 仓库 ID,例如 "Qwen/Qwen3-30B-A3B" + platform: 平台,"hf" (HuggingFace) 或 "ms" (ModelScope) + revision: 版本/分支,默认 HuggingFace 为 "main",ModelScope 为 "master" + use_mirror: 是否使用镜像(仅对 HuggingFace 有效) + timeout: 网络请求超时时间(秒),None 表示不设置超时 + + Returns: + dict: 文件名到 sha256 的映射,例如 {"model-00001-of-00016.safetensors": "abc123...", "config.json": "def456..."} + """ + if platform == "hf": + # 先尝试直连,失败后自动使用镜像 + try: + if use_mirror: + return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=True, timeout=timeout) + else: + return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=False, timeout=timeout) + except Exception as e: + # 如果不是镜像模式且失败了,自动重试使用镜像 + if not use_mirror: + return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=True, timeout=timeout) + else: + raise e + elif platform == "ms": + return _fetch_from_modelscope(repo_id, revision or "master", timeout=timeout) + else: + raise ValueError(f"不支持的平台: {platform},请使用 'hf' 或 'ms'") + + +def _fetch_from_huggingface( + repo_id: str, revision: str, use_mirror: bool = False, timeout: int | None = None +) -> dict[str, str]: + """从 HuggingFace 获取所有重要文件的 sha256 + + Args: + repo_id: 仓库 ID + revision: 版本/分支 + use_mirror: 是否使用镜像(hf-mirror.com) + timeout: 网络请求超时时间(秒),None 表示不设置超时 + """ + import os + import socket + + # 如果需要使用镜像,设置环境变量 + original_endpoint = os.environ.get("HF_ENDPOINT") + if use_mirror and not original_endpoint: + os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" + + # Set socket timeout if specified + original_timeout = socket.getdefaulttimeout() + if timeout is not None: + socket.setdefaulttimeout(timeout) + + from huggingface_hub import HfApi, list_repo_files + + try: + api = HfApi() + all_files = list_repo_files(repo_id=repo_id, revision=revision) + + # 筛选重要文件:*.safetensors, *.json, *.py + important_files = [f for f in all_files if f.endswith((".safetensors", ".json", ".py"))] + + 
def verify_model_integrity_with_progress(
    repo_type: Literal["huggingface", "modelscope"],
    repo_id: str,
    local_dir: Path,
    progress_callback=None,
    verbose: bool = False,
    use_mirror: bool = False,
    files_to_verify: list[str] | None = None,
    timeout: int | None = None,
) -> Dict[str, Any]:
    """
    Verify local model files against remote checksums, reporting progress.

    Fetches the remote hashes via fetch_model_sha256(), hashes the local
    *.safetensors files via calculate_local_sha256(), and compares the two by
    file basename.

    The progress_callback, when given, is always invoked with three positional
    arguments (message, total, current); total and current are None for plain
    status messages.

    Args:
        repo_type: Repository type ("huggingface" or "modelscope")
        repo_id: Repository ID
        local_dir: Local directory path
        progress_callback: Optional callback for progress updates
        verbose: If True, include remote/local hashes in each per-file message
        use_mirror: If True, use HuggingFace mirror (hf-mirror.com)
        files_to_verify: Optional list of specific files to verify (for
            re-verification); entries may carry " (missing)" or
            " (hash mismatch)" markers, which are stripped
        timeout: Network request timeout in seconds (None = no timeout)

    Returns:
        Dict with "status" ("passed", "failed" or "error"), "files_checked",
        "files_passed" and "files_failed"; "error_message" is present for
        "failed"/"error", and "is_network_error" for errors that came from a
        caught exception.
    """

    def report_progress(msg: str, total=None, current=None):
        """Forward a progress event to the caller's callback, if any."""
        if progress_callback:
            progress_callback(msg, total, current)

    try:
        platform = "hf" if repo_type == "huggingface" else "ms"

        # 1. Fetch official SHA256 hashes
        if files_to_verify:
            report_progress(f"Fetching SHA256 hashes for {len(files_to_verify)} files...")
        elif use_mirror and platform == "hf":
            report_progress("Fetching official SHA256 hashes from mirror (hf-mirror.com)...")
        else:
            report_progress("Fetching official SHA256 hashes from remote repository...")

        official_hashes = fetch_model_sha256(repo_id, platform, use_mirror=use_mirror, timeout=timeout)

        # Reduce the requested entries to bare basenames once; both the remote
        # filter and the local file selection below use the same set.
        requested_names = set()
        if files_to_verify:
            for entry in files_to_verify:
                stripped = entry.replace(" (missing)", "").replace(" (hash mismatch)", "").strip()
                requested_names.add(Path(stripped).name)

            # Remote keys might contain sub-paths, so compare by basename.
            official_hashes = {k: v for k, v in official_hashes.items() if Path(k).name in requested_names}

        report_progress(f"✓ Fetched {len(official_hashes)} file hashes from remote")

        if not official_hashes:
            return {
                "status": "error",
                "files_checked": 0,
                "files_passed": 0,
                "files_failed": [],
                "error_message": f"No safetensors files found in remote repository: {repo_id}",
            }

        # 2. Calculate local SHA256 hashes
        # NOTE(review): fetch_model_sha256 documents that it also returns
        # *.json and *.py entries, while only *.safetensors are hashed here;
        # such entries would be reported as missing — confirm this is intended.
        local_dir_path = Path(local_dir)
        local_weights = [f for f in local_dir_path.glob("*.safetensors") if f.is_file()]
        if files_to_verify:
            files_to_hash = [f for f in local_weights if f.name in requested_names]
        else:
            files_to_hash = local_weights

        total_files = len(files_to_hash)

        if files_to_verify:
            report_progress(f"Calculating SHA256 for {total_files} repaired files...", total=total_files, current=0)
        else:
            report_progress("Calculating SHA256 for local files...", total=total_files, current=0)

        completed_count = 0

        def hash_progress_callback(msg: str):
            """Translate the hasher's text messages into counted progress events."""
            nonlocal completed_count
            if "Using" in msg and "workers" in msg:
                report_progress(msg)
            elif "[" in msg and "/" in msg and "]" in msg:
                # Per-file line such as: [1/10] ✓ filename (123.4 MB)
                completed_count += 1
                report_progress(msg, total=total_files, current=completed_count)

        # The pre-filtered list is only passed for re-verification; otherwise
        # the hasher globs the directory itself.
        local_hashes = calculate_local_sha256(
            local_dir_path,
            "*.safetensors",
            progress_callback=hash_progress_callback,
            files_list=files_to_hash if files_to_verify else None,
        )
        report_progress(f"✓ Calculated {len(local_hashes)} local file hashes")

        # 3. Compare hashes
        total_remote = len(official_hashes)
        report_progress(f"Comparing {total_remote} files...", total=total_remote, current=0)

        # Index local hashes by basename once (first occurrence wins) instead
        # of rescanning every local entry for every remote file.
        local_by_basename = {}
        for local_name, digest in local_hashes.items():
            local_by_basename.setdefault(Path(local_name).name, digest)

        files_failed = []
        files_missing = []
        files_passed = 0

        for idx, (filename, official_hash) in enumerate(official_hashes.items(), 1):
            file_basename = Path(filename).name
            local_hash = local_by_basename.get(file_basename)

            if local_hash is None:
                files_missing.append(filename)
                headline = f"[{idx}/{total_remote}] ✗ {file_basename} (missing)"
                detail = f"\n Remote: {official_hash}\n Local: "
            elif local_hash.lower() != official_hash.lower():
                files_failed.append(f"{filename} (hash mismatch)")
                headline = f"[{idx}/{total_remote}] ✗ {file_basename} (hash mismatch)"
                detail = f"\n Remote: {official_hash}\n Local: {local_hash}"
            else:
                files_passed += 1
                headline = f"[{idx}/{total_remote}] ✓ {file_basename}"
                detail = f"\n Remote: {official_hash}\n Local: {local_hash}"

            report_progress(headline + detail if verbose else headline, total=total_remote, current=idx)

        # 4. Return results
        if files_failed or files_missing:
            all_failed = files_failed + [f"{f} (missing)" for f in files_missing]
            return {
                "status": "failed",
                "files_checked": total_remote,
                "files_passed": files_passed,
                "files_failed": all_failed,
                "error_message": f"{len(all_failed)} file(s) failed verification",
            }

        return {
            "status": "passed",
            "files_checked": total_remote,
            "files_passed": files_passed,
            "files_failed": [],
        }

    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
        requests.exceptions.RequestException,
        TimeoutError,  # Socket timeout from socket.setdefaulttimeout()
        OSError,  # Network-related OS errors
    ) as e:
        error_msg = f"Network error: {str(e)}"
        if repo_type == "huggingface":
            error_msg += "\n\nTry using HuggingFace mirror:\n export HF_ENDPOINT=https://hf-mirror.com"
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": error_msg,
            "is_network_error": True,
        }
    except Exception as e:
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": f"Verification failed: {str(e)}",
            "is_network_error": False,
        }
+ + Args: + user_model: UserModel object to verify + user_registry: UserModelRegistry instance + operation_name: Name of the operation (e.g., "running", "quantizing") + """ + from rich.prompt import Prompt, Confirm + from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, MofNCompleteColumn, TimeElapsedColumn + from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError + from kt_kernel.cli.i18n import get_lang + from kt_kernel.cli.utils.console import console, print_info, print_warning, print_error, print_success, print_step + import typer + + lang = get_lang() + + # Check if already verified + if user_model.sha256_status == "passed": + console.print() + print_info("Model integrity already verified ✓") + console.print() + return + + # Model not verified yet + console.print() + console.print("[bold yellow]═══ Model Integrity Check ═══[/bold yellow]") + console.print() + + # Check if repo_id exists + if not user_model.repo_id: + # No repo_id - ask user to provide one + console.print("[yellow]No repository ID configured for this model.[/yellow]") + console.print() + console.print("To verify model integrity, we need the repository ID (e.g., 'deepseek-ai/DeepSeek-V3')") + console.print() + + if not Confirm.ask("Would you like to configure repository ID now?", default=True): + console.print() + print_warning(f"Skipping verification. 
Model will be used for {operation_name} without integrity check.") + console.print() + return + + # Ask for repo type + console.print() + console.print("Repository type:") + console.print(" [cyan][1][/cyan] HuggingFace") + console.print(" [cyan][2][/cyan] ModelScope") + console.print() + + repo_type_choice = Prompt.ask("Select repository type", choices=["1", "2"], default="1") + repo_type = "huggingface" if repo_type_choice == "1" else "modelscope" + + # Ask for repo_id + console.print() + repo_id = Prompt.ask("Enter repository ID (e.g., deepseek-ai/DeepSeek-V3)") + + # Update model + user_registry.update_model(user_model.name, {"repo_type": repo_type, "repo_id": repo_id}) + user_model.repo_type = repo_type + user_model.repo_id = repo_id + + console.print() + print_success(f"Repository configured: {repo_type}:{repo_id}") + console.print() + + # Now ask if user wants to verify + console.print("[dim]Model integrity verification is a one-time check that ensures your[/dim]") + console.print("[dim]model weights are not corrupted. This helps prevent runtime errors.[/dim]") + console.print() + + if not Confirm.ask(f"Would you like to verify model integrity before {operation_name}?", default=True): + console.print() + print_warning(f"Skipping verification. 
Model will be used for {operation_name} without integrity check.") + console.print() + return + + # Perform verification + console.print() + print_step("Verifying model integrity...") + console.print() + + # Check connectivity first + use_mirror = False + if user_model.repo_type == "huggingface": + with console.status("[dim]Checking HuggingFace connectivity...[/dim]"): + is_accessible, message = check_huggingface_connectivity(timeout=5) + + if not is_accessible: + print_warning("HuggingFace Connection Failed") + console.print() + console.print(f" {message}") + console.print() + console.print(" [yellow]Auto-switching to HuggingFace mirror:[/yellow] [cyan]hf-mirror.com[/cyan]") + console.print() + use_mirror = True + + # Fetch remote hashes with timeout + def fetch_with_timeout(repo_type, repo_id, use_mirror, timeout): + """Fetch hashes with timeout.""" + executor = ThreadPoolExecutor(max_workers=1) + try: + platform = "hf" if repo_type == "huggingface" else "ms" + future = executor.submit(fetch_model_sha256, repo_id, platform, use_mirror=use_mirror, timeout=timeout) + hashes = future.result(timeout=timeout) + executor.shutdown(wait=False) + return (hashes, False) + except (FutureTimeoutError, Exception): + executor.shutdown(wait=False) + return (None, True) + + # Try fetching hashes + status = console.status("[dim]Fetching remote hashes...[/dim]") + status.start() + official_hashes, timed_out = fetch_with_timeout(user_model.repo_type, user_model.repo_id, use_mirror, 10) + status.stop() + + # Handle timeout with fallback + if timed_out and user_model.repo_type == "huggingface" and not use_mirror: + print_warning("HuggingFace Fetch Timeout (10s)") + console.print() + console.print(" [yellow]Trying HuggingFace mirror...[/yellow]") + console.print() + + status = console.status("[dim]Fetching remote hashes from mirror...[/dim]") + status.start() + official_hashes, timed_out = fetch_with_timeout(user_model.repo_type, user_model.repo_id, True, 10) + status.stop() + + if 
timed_out and user_model.repo_type == "huggingface": + print_warning("HuggingFace Mirror Timeout (10s)") + console.print() + console.print(" [yellow]Fallback to ModelScope...[/yellow]") + console.print() + + status = console.status("[dim]Fetching remote hashes from ModelScope...[/dim]") + status.start() + official_hashes, timed_out = fetch_with_timeout("modelscope", user_model.repo_id, False, 10) + status.stop() + + if not official_hashes or timed_out: + print_error("Failed to fetch remote hashes (network timeout)") + console.print() + console.print(" [yellow]Unable to verify model integrity due to network issues.[/yellow]") + console.print() + + if not Confirm.ask(f"Continue {operation_name} without verification?", default=False): + raise typer.Exit(0) + + console.print() + return + + console.print(f" [green]✓ Fetched {len(official_hashes)} file hashes[/green]") + console.print() + + # Calculate local hashes and compare + local_dir = Path(user_model.path) + files_to_hash = [f for f in local_dir.glob("*.safetensors") if f.is_file()] + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + console=console, + ) as progress: + # Calculate local hashes + task = progress.add_task("[yellow]Calculating local SHA256...", total=len(files_to_hash)) + + def hash_callback(msg): + if "[" in msg and "/" in msg and "]" in msg and "✓" in msg: + progress.advance(task) + + local_hashes = calculate_local_sha256(local_dir, "*.safetensors", progress_callback=hash_callback) + progress.remove_task(task) + + console.print(f" [green]✓ Calculated {len(local_hashes)} local hashes[/green]") + console.print() + + # Compare hashes + task = progress.add_task("[blue]Comparing hashes...", total=len(official_hashes)) + + files_failed = [] + files_missing = [] + files_passed = 0 + + for filename, official_hash in official_hashes.items(): + file_basename = Path(filename).name + local_hash = None + + 
for local_file, local_hash_value in local_hashes.items(): + if Path(local_file).name == file_basename: + local_hash = local_hash_value + break + + if local_hash is None: + files_missing.append(filename) + elif local_hash.lower() != official_hash.lower(): + files_failed.append(f"{filename} (hash mismatch)") + else: + files_passed += 1 + + progress.advance(task) + + progress.remove_task(task) + + console.print() + + # Check results + if not files_failed and not files_missing: + # Verification passed + user_registry.update_model(user_model.name, {"sha256_status": "passed"}) + print_success("Model integrity verification PASSED ✓") + console.print() + console.print(f" All {files_passed} files verified successfully") + console.print() + else: + # Verification failed + user_registry.update_model(user_model.name, {"sha256_status": "failed"}) + print_error(f"Model integrity verification FAILED") + console.print() + console.print(f" ✓ Passed: [green]{files_passed}[/green]") + console.print(f" ✗ Failed: [red]{len(files_failed) + len(files_missing)}[/red]") + console.print() + + if files_missing: + console.print(f" [red]Missing files ({len(files_missing)}):[/red]") + for f in files_missing[:5]: + console.print(f" - {Path(f).name}") + if len(files_missing) > 5: + console.print(f" ... and {len(files_missing) - 5} more") + console.print() + + if files_failed: + console.print(f" [red]Hash mismatch ({len(files_failed)}):[/red]") + for f in files_failed[:5]: + console.print(f" - {f}") + if len(files_failed) > 5: + console.print(f" ... 
def is_port_available(host: str, port: int) -> bool:
    """Check if a port is available on the given host.

    The check is a TCP connect probe (not a bind): the port counts as
    occupied when something accepts a connection on it. The wildcard address
    "0.0.0.0" is probed through "127.0.0.1".

    Args:
        host: Host address (e.g., "0.0.0.0", "127.0.0.1")
        port: Port number to check

    Returns:
        True if nothing accepted the connection, False if the port is
        occupied or the probe itself could not be carried out (bad host,
        port outside 0-65535, socket error).
    """
    probe_host = "127.0.0.1" if host == "0.0.0.0" else host
    try:
        # The context manager closes the socket even when connect_ex raises
        # (e.g. unresolvable host or out-of-range port).
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1)
            # connect_ex returns 0 when the connection succeeds, i.e. a
            # listener already owns the port; any errno means it is free.
            return sock.connect_ex((probe_host, port)) != 0
    except (OSError, OverflowError, TypeError, ValueError):
        # If the probe cannot run, conservatively report "not available".
        return False


def find_available_port(host: str, start_port: int, max_attempts: int = 100) -> Tuple[bool, int]:
    """Find an available port starting from start_port.

    Args:
        host: Host address
        start_port: Starting port number to check
        max_attempts: Maximum number of ports to try

    Returns:
        Tuple of (found, port_number)
        - found: True if an available port was found
        - port_number: The available port number (or start_port if not found)
    """
    # Ports above 65535 can never be available, so do not bother probing them.
    stop = min(start_port + max_attempts, 65536)
    for port in range(start_port, stop):
        if is_port_available(host, port):
            return True, port

    return False, start_port
def select_model_to_quantize() -> Optional[Any]:
    """Ask the user which registered model should be quantized.

    Only registry entries whose format is "safetensors", that
    is_amx_weights() does not report as AMX weights, and whose is_moe flag is
    truthy are offered. The user picks one by its 1-based table index.

    Returns:
        The chosen entry from the displayed table, or None when no model
        qualifies or the entered index is out of range.
    """
    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
    from kt_kernel.cli.commands.model import is_amx_weights, SHA256_STATUS_MAP
    from kt_kernel.cli.utils.model_table_builder import build_moe_gpu_table

    def _quantizable(entry) -> bool:
        """True for safetensors, non-AMX, MoE registry entries."""
        if entry.format != "safetensors":
            return False
        amx_detected, _ = is_amx_weights(entry.path)
        if amx_detected:
            return False
        return bool(entry.is_moe)

    candidates = [entry for entry in UserModelRegistry().list_models() if _quantizable(entry)]

    if not candidates:
        # Nothing to offer: explain why and how to register models.
        console.print(f"[yellow]{t('quant_no_moe_models')}[/yellow]")
        console.print()
        console.print(f" {t('quant_only_moe')}")
        console.print()
        console.print(f" {t('quant_add_models', command='kt model scan')}")
        console.print(f" {t('quant_add_models', command='kt model add ')}")
        return None

    console.print()
    console.print(f"[bold green]{t('quant_moe_available')}[/bold green]")
    console.print()

    # Render with the shared table builder so the listing matches other commands.
    table, shown = build_moe_gpu_table(
        models=candidates, status_map=SHA256_STATUS_MAP, show_index=True, start_index=1
    )
    console.print(table)
    console.print()

    picked = IntPrompt.ask(t("quant_select_model"), default=1, show_choices=False)
    if 1 <= picked <= len(shown):
        return shown[picked - 1]

    console.print(f"[red]{t('quant_invalid_choice')}[/red]")
    return None
def configure_cpu_params(max_cores: int, max_numa: int) -> Dict[str, Any]:
    """Interactively collect CPU thread count, NUMA node count and GPU usage.

    Args:
        max_cores: Upper bound accepted for the CPU thread count.
        max_numa: Upper bound accepted for the NUMA node count.

    Returns:
        Dict with "cpu_threads" and "numa_nodes" (each the entered value when
        it lies in [1, max], otherwise the offered default, which is never
        below 1) and "use_gpu" (bool).
    """
    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_step3_cpu')}[/bold cyan]", expand=False))
    console.print()

    def clamp(value: int, min_val: int, max_val: int, default: int) -> int:
        """Return value when it lies within [min_val, max_val], else default."""
        return value if min_val <= value <= max_val else default

    # int(max_cores * 0.8) is 0 on a single-core host; never offer (or fall
    # back to) zero threads.
    default_threads = max(1, int(max_cores * 0.8))
    cpu_threads = IntPrompt.ask(t("quant_cpu_threads_prompt", max=max_cores), default=default_threads)
    cpu_threads = clamp(cpu_threads, 1, max_cores, default_threads)

    # Same guard for NUMA in case the detected node count is 0.
    default_numa = max(1, max_numa)
    numa_nodes = IntPrompt.ask(t("quant_numa_nodes_prompt", max=max_numa), default=default_numa)
    numa_nodes = clamp(numa_nodes, 1, max_numa, default_numa)

    # Ask about GPU usage
    console.print()
    console.print(f"[bold]{t('quant_use_gpu_label')}[/bold]")
    console.print(f" [dim]{t('quant_gpu_speedup')}[/dim]")
    console.print()
    use_gpu = Confirm.ask(t("quant_enable_gpu"), default=True)

    return {"cpu_threads": cpu_threads, "numa_nodes": numa_nodes, "use_gpu": use_gpu}
def check_disk_space(output_path: Path, required_size_gb: float) -> tuple[float, bool]:
    """
    Report free disk space for the filesystem that will hold output_path.

    When output_path does not exist yet, the nearest existing ancestor
    directory is measured instead.

    Args:
        output_path: Target output path
        required_size_gb: Required space in GB

    Returns:
        Tuple of (available_gb, is_sufficient); is_sufficient is True when
        available_gb >= required_size_gb * 1.2 (20% headroom). Any failure
        while probing yields (0.0, False).
    """
    import shutil

    try:
        # Start from the path itself if present, otherwise from its parent,
        # then climb until something exists or the root is reached.
        probe = output_path if output_path.exists() else output_path.parent
        while True:
            if probe.exists() or probe == probe.parent:
                break
            probe = probe.parent

        free_gb = shutil.disk_usage(probe).free / (1024**3)
        enough = free_gb >= (required_size_gb * 1.2)
        return free_gb, enough
    except Exception:
        return 0.0, False
+ """ + from kt_kernel.cli.utils.environment import detect_cpu_info + + # Get CPU info + cpu_info = detect_cpu_info() + + # Step 1: Select model + model = select_model_to_quantize() + if not model: + return None + + # Step 1.5: Pre-quantization verification (optional) + from kt_kernel.cli.utils.user_model_registry import UserModelRegistry + from kt_kernel.cli.utils.model_verifier import pre_operation_verification + + user_registry = UserModelRegistry() + user_model_obj = user_registry.find_by_path(model.path) + + if user_model_obj and user_model_obj.format == "safetensors": + pre_operation_verification(user_model_obj, user_registry, operation_name="quantizing") + + # Step 2: Configure quantization method + quant_config = configure_quantization_method() + + # Step 3: Configure CPU parameters + cpu_config = configure_cpu_params(cpu_info.threads, cpu_info.numa_nodes) # Use logical threads + + # Step 4: Configure output path + output_path = configure_output_path(model, quant_config["method"], cpu_config["numa_nodes"]) + + # Step 4.5: Check if output path already exists and generate unique name + if output_path.exists(): + console.print() + console.print(t("quant_output_exists_warn", path=str(output_path))) + console.print() + + # Generate unique name by adding suffix + original_name = output_path.name + parent_dir = output_path.parent + counter = 2 + + while output_path.exists(): + new_name = f"{original_name}-{counter}" + output_path = parent_dir / new_name + counter += 1 + + console.print(t("quant_using_unique_name", path=str(output_path))) + console.print() + + # Step 5: Calculate space requirements and check availability + console.print() + console.print(Panel(f"[bold cyan]{t('quant_disk_analysis')}[/bold cyan]", expand=False)) + console.print() + + source_size_gb, estimated_size_gb = calculate_quantized_size( + Path(model.path), quant_config["input_type"], quant_config["method"] + ) + + available_gb, is_sufficient = check_disk_space(output_path, estimated_size_gb) 
+ + console.print(f" {t('quant_source_size'):<26} [cyan]{source_size_gb:.2f} GB[/cyan]") + console.print(f" {t('quant_estimated_size'):<26} [yellow]{estimated_size_gb:.2f} GB[/yellow]") + console.print( + f" {t('quant_available_space'):<26} [{'green' if is_sufficient else 'red'}]{available_gb:.2f} GB[/{'green' if is_sufficient else 'red'}]" + ) + console.print() + + if not is_sufficient: + required_with_buffer = estimated_size_gb * 1.2 + console.print(f"[bold red]⚠ {t('quant_insufficient_space')}[/bold red]") + console.print() + console.print(f" {t('quant_required_space'):<26} [yellow]{required_with_buffer:.2f} GB[/yellow]") + console.print(f" {t('quant_available_space'):<26} [red]{available_gb:.2f} GB[/red]") + console.print(f" {t('quant_shortage'):<26} [red]{required_with_buffer - available_gb:.2f} GB[/red]") + console.print() + console.print(f" {t('quant_may_fail')}") + console.print() + + if not Confirm.ask(f"[yellow]{t('quant_continue_anyway')}[/yellow]", default=False): + console.print(f"[yellow]{t('quant_cancelled')}[/yellow]") + return None + console.print() + + # Summary and confirmation + console.print() + console.print(Panel(f"[bold cyan]{t('quant_config_summary')}[/bold cyan]", expand=False)) + console.print() + console.print(f" {t('quant_summary_model'):<15} {model.name}") + console.print(f" {t('quant_summary_method'):<15} {quant_config['method'].upper()}") + console.print(f" {t('quant_summary_input_type'):<15} {quant_config['input_type'].upper()}") + console.print(f" {t('quant_summary_cpu_threads'):<15} {cpu_config['cpu_threads']}") + console.print(f" {t('quant_summary_numa'):<15} {cpu_config['numa_nodes']}") + console.print(f" {t('quant_summary_gpu'):<15} {t('yes') if cpu_config['use_gpu'] else t('no')}") + console.print(f" {t('quant_summary_output'):<15} {output_path}") + console.print() + + if not Confirm.ask(f"[bold green]{t('quant_start_question')}[/bold green]", default=True): + console.print(f"[yellow]{t('quant_cancelled')}[/yellow]") + return 
def parse_readme_frontmatter(readme_path: Path) -> Optional[Dict]:
    """
    Parse the YAML frontmatter block at the top of a README.md.

    The frontmatter is the text between a ``---`` line at the very start of
    the file and the next ``---`` line.

    Args:
        readme_path: Path to README.md file

    Returns:
        Dictionary of frontmatter data, or None if the file is missing or
        unreadable, has no frontmatter, or the frontmatter is not a YAML
        mapping
    """
    try:
        # "utf-8-sig" decodes plain UTF-8 as well; it only strips a leading
        # BOM, which would otherwise keep the "^---" anchor from matching.
        content = readme_path.read_text(encoding="utf-8-sig")
    except (OSError, UnicodeError):
        # Missing file, directory, permission problem or undecodable bytes.
        return None

    # The closing marker may be the last line of the file, so accept
    # end-of-input as well as a newline after it.
    match = re.match(r"^---\s*\n(.*?)\n---\s*(?:\n|\Z)", content, re.DOTALL)
    if not match:
        return None

    try:
        data = yaml.safe_load(match.group(1))
    except Exception:
        # Deliberately broad: detection is best-effort, and the YAML
        # constructors may raise errors that are not YAMLError subclasses
        # (for instance on out-of-range date values).
        return None

    return data if isinstance(data, dict) else None
if result: + return result + + # Priority 2: Try to find repo_id from other fields + repo_id = None + + # Check base_model field + base_model = frontmatter.get("base_model") + if base_model: + if isinstance(base_model, list) and len(base_model) > 0: + # base_model is a list, take first item + repo_id = base_model[0] + elif isinstance(base_model, str): + repo_id = base_model + + # Check model-index field + if not repo_id: + model_index = frontmatter.get("model-index") + if isinstance(model_index, list) and len(model_index) > 0: + first_model = model_index[0] + if isinstance(first_model, dict): + repo_id = first_model.get("name") + + # Check model_name field + if not repo_id: + repo_id = frontmatter.get("model_name") + + if not repo_id or not isinstance(repo_id, str): + return None + + # Validate format: should be "namespace/model-name" + if "/" not in repo_id: + return None + + parts = repo_id.split("/") + if len(parts) != 2: + return None + + # Determine repo type + repo_type = "huggingface" # Default + + # Look for ModelScope indicators + if "modelscope" in repo_id.lower(): + repo_type = "modelscope" + + # Check tags + tags = frontmatter.get("tags", []) + if isinstance(tags, list): + if "modelscope" in [str(t).lower() for t in tags]: + repo_type = "modelscope" + + return (repo_id, repo_type) + + +def _extract_repo_from_url(url: str) -> Optional[Tuple[str, str]]: + """ + Extract repo_id and repo_type from a URL + + Supports: + - https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/LICENSE + - https://modelscope.cn/models/Qwen/Qwen3-30B-A3B + + Args: + url: URL string + + Returns: + Tuple of (repo_id, repo_type) or None + """ + # HuggingFace pattern: https://huggingface.co/{namespace}/{model}/... 
+ hf_match = re.match(r"https?://huggingface\.co/([^/]+)/([^/]+)", url) + if hf_match: + namespace = hf_match.group(1) + model_name = hf_match.group(2) + repo_id = f"{namespace}/{model_name}" + return (repo_id, "huggingface") + + # ModelScope pattern: https://modelscope.cn/models/{namespace}/{model} + ms_match = re.match(r"https?://(?:www\.)?modelscope\.cn/models/([^/]+)/([^/]+)", url) + if ms_match: + namespace = ms_match.group(1) + model_name = ms_match.group(2) + repo_id = f"{namespace}/{model_name}" + return (repo_id, "modelscope") + + return None + + +def extract_repo_from_global_search(readme_path: Path) -> Optional[Tuple[str, str]]: + """ + Extract repo info by globally searching for URLs in README.md + + Args: + readme_path: Path to README.md file + + Returns: + Tuple of (repo_id, repo_type) or None if not found + """ + if not readme_path.exists(): + return None + + try: + with open(readme_path, "r", encoding="utf-8") as f: + content = f.read() + + # Find all HuggingFace URLs + hf_pattern = r"https?://huggingface\.co/([^/\s]+)/([^/\s\)]+)" + hf_matches = re.findall(hf_pattern, content) + + # Find all ModelScope URLs + ms_pattern = r"https?://(?:www\.)?modelscope\.cn/models/([^/\s]+)/([^/\s\)]+)" + ms_matches = re.findall(ms_pattern, content) + + # Collect all found repos with their types + found_repos = [] + + for namespace, model_name in hf_matches: + # Skip common non-repo paths + if namespace.lower() in ["docs", "blog", "spaces", "datasets"]: + continue + if model_name.lower() in ["tree", "blob", "raw", "resolve", "discussions"]: + continue + + repo_id = f"{namespace}/{model_name}" + found_repos.append((repo_id, "huggingface")) + + for namespace, model_name in ms_matches: + repo_id = f"{namespace}/{model_name}" + found_repos.append((repo_id, "modelscope")) + + if not found_repos: + return None + + # If multiple different repos found, use the last one + # First, deduplicate + seen = {} + for repo_id, repo_type in found_repos: + seen[repo_id] = repo_type # 
def detect_repo_for_model(model_path: str) -> Optional[Tuple[str, str]]:
    """
    Work out which repository a local model directory came from.

    Only the YAML frontmatter of the directory's README.md is consulted;
    there is no fallback to scanning the README body for URLs.

    Args:
        model_path: Path to model directory

    Returns:
        Tuple of (repo_id, repo_type) or None if not detected
    """
    root = Path(model_path)
    if not (root.exists() and root.is_dir()):
        return None

    readme = root / "README.md"
    if not readme.exists():
        return None

    # Frontmatter is the single source of truth; an empty or missing
    # mapping means nothing can be detected.
    metadata = parse_readme_frontmatter(readme)
    return extract_repo_from_frontmatter(metadata) if metadata else None
# Already has repo_id + } + """ + results = {"detected": [], "not_detected": [], "skipped": []} + + for model in model_list: + # Skip if already has repo_id + if model.repo_id: + results["skipped"].append(model) + continue + + # Only process safetensors and gguf models + if model.format not in ["safetensors", "gguf"]: + results["skipped"].append(model) + continue + + # Try to detect repo + repo_info = detect_repo_for_model(model.path) + + if repo_info: + repo_id, repo_type = repo_info + results["detected"].append((model, repo_id, repo_type)) + else: + results["not_detected"].append(model) + + return results + + +def format_detection_report(results: Dict) -> str: + """ + Format scan results into a readable report + + Args: + results: Results from scan_models_for_repo() + + Returns: + Formatted string report + """ + lines = [] + + lines.append("=" * 80) + lines.append("Auto-Detection Report") + lines.append("=" * 80) + lines.append("") + + # Detected + if results["detected"]: + lines.append(f"✓ Detected repository information ({len(results['detected'])} models):") + lines.append("") + for model, repo_id, repo_type in results["detected"]: + lines.append(f" • {model.name}") + lines.append(f" Path: {model.path}") + lines.append(f" Repo: {repo_id} ({repo_type})") + lines.append("") + + # Not detected + if results["not_detected"]: + lines.append(f"✗ No repository information found ({len(results['not_detected'])} models):") + lines.append("") + for model in results["not_detected"]: + lines.append(f" • {model.name}") + lines.append(f" Path: {model.path}") + lines.append("") + + # Skipped + if results["skipped"]: + lines.append(f"⊘ Skipped ({len(results['skipped'])} models):") + lines.append(f" (Already have repo_id or not safetensors/gguf format)") + lines.append("") + + lines.append("=" * 80) + lines.append( + f"Summary: {len(results['detected'])} detected, " + f"{len(results['not_detected'])} not detected, " + f"{len(results['skipped'])} skipped" + ) + lines.append("=" * 
class RunConfigManager:
    """Manager for saved run configurations.

    Configurations are stored in a single YAML file as
    ``{"version": ..., "configs": {model_id: [config, ...]}}``.
    A missing, unreadable or malformed file is treated as empty.
    """

    def __init__(self):
        self.config_file = CONFIG_FILE
        self._ensure_config_file()

    @staticmethod
    def _empty_data() -> Dict[str, Any]:
        """Return a fresh, empty config document."""
        return {"version": "1.0", "configs": {}}

    def _ensure_config_file(self):
        """Ensure config file exists."""
        if not self.config_file.exists():
            self.config_file.parent.mkdir(parents=True, exist_ok=True)
            self._save_data(self._empty_data())

    def _load_data(self) -> Dict:
        """Load raw config data.

        Returns:
            A dict that always has a dict under the "configs" key, so
            callers can index it without further checks.
        """
        try:
            with open(self.config_file, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
        except Exception:
            # Best-effort: an unreadable or corrupt file behaves like an empty one.
            return self._empty_data()

        # The file may have been hand-edited into something that is not a
        # mapping (empty file, list, scalar) or that has "configs: null".
        if not isinstance(data, dict) or not data:
            return self._empty_data()
        if not isinstance(data.get("configs"), dict):
            data["configs"] = {}
        return data

    def _save_data(self, data: Dict):
        """Save raw config data."""
        with open(self.config_file, "w", encoding="utf-8") as f:
            yaml.dump(data, f, allow_unicode=True, default_flow_style=False)

    def list_configs(self, model_id: str) -> List[Dict[str, Any]]:
        """List all saved configs for a model.

        Returns:
            List of config dicts with 'config_name' and other fields.
            Empty if the model has no saved configs.
        """
        configs = self._load_data()["configs"].get(model_id, [])
        return configs if isinstance(configs, list) else []

    def save_config(self, model_id: str, config: Dict[str, Any]):
        """Save a configuration for a model.

        The stored entry is a shallow copy of ``config`` with a
        ``created_at`` timestamp added; the caller's dict is not modified.

        Args:
            model_id: Model ID to bind config to
            config: Configuration dict with all run parameters
        """
        data = self._load_data()

        saved = data["configs"].get(model_id)
        if not isinstance(saved, list):
            # Missing or malformed entry: start a new list for this model.
            saved = []
            data["configs"][model_id] = saved

        entry = dict(config)
        entry["created_at"] = datetime.now().isoformat()
        saved.append(entry)

        self._save_data(data)

    def delete_config(self, model_id: str, config_index: int) -> bool:
        """Delete a saved configuration.

        Args:
            model_id: Model ID
            config_index: Index of config to delete (0-based)

        Returns:
            True if deleted, False if not found
        """
        data = self._load_data()

        configs = data["configs"].get(model_id)
        if not isinstance(configs, list):
            return False
        if not 0 <= config_index < len(configs):
            return False

        del configs[config_index]
        self._save_data(data)
        return True

    def get_config(self, model_id: str, config_index: int) -> Optional[Dict[str, Any]]:
        """Get a specific saved configuration.

        Args:
            model_id: Model ID
            config_index: Index of config to get (0-based)

        Returns:
            Config dict or None if not found
        """
        configs = self.list_configs(model_id)
        if not 0 <= config_index < len(configs):
            return None
        return configs[config_index]
def select_model() -> Optional[Any]:
    """Step 1: Select a safetensors MoE model.

    Lists the registered safetensors models that exist on disk, are not AMX
    weights, and are either known MoE models or not yet analyzed, then asks
    the user to pick one by number.

    Returns:
        Selected UserModel object, or None if no eligible model is registered.
    """
    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
    from kt_kernel.cli.commands.model import is_amx_weights

    registry = UserModelRegistry()
    all_models = registry.list_models()

    # Filter: safetensors models only (exclude AMX and GGUF)
    # Then filter to only show MoE models (matching kt model list behavior)
    moe_models = []
    for model in all_models:
        if model.format != "safetensors" or not model.path_exists():
            continue
        is_amx, _ = is_amx_weights(model.path)
        if is_amx:
            continue
        # Only include MoE models (is_moe == True)
        # Also include models not yet analyzed (is_moe == None) for backwards compatibility
        if model.is_moe is True or model.is_moe is None:
            moe_models.append(model)

    if not moe_models:
        console.print(f"[yellow]{t('run_int_no_moe_models')}[/yellow]")
        console.print(f" {t('run_int_add_models')}")
        console.print(f" {t('run_int_list_all')}")
        return None

    console.print()
    console.print(Panel(f"[bold cyan]{t('run_int_step1_title')}[/bold cyan]", expand=False))
    console.print()

    # Display models using same format as kt model list
    from kt_kernel.cli.utils.model_scanner import format_size
    from kt_kernel.cli.commands.model import SHA256_STATUS_MAP

    table = Table(box=box.ROUNDED, show_header=True, header_style="bold cyan")
    table.add_column("#", justify="right", style="cyan", no_wrap=True)
    table.add_column("Name", style="cyan", no_wrap=True)
    table.add_column("Path", style="dim", overflow="fold")
    table.add_column("Total", justify="right")
    table.add_column("Exps", justify="center", style="yellow")
    table.add_column("Act", justify="center", style="green")
    table.add_column("MoE Size", justify="right", style="cyan")
    table.add_column("Repo", style="dim", overflow="fold")
    table.add_column("SHA256", justify="center")

    for i, model in enumerate(moe_models, 1):
        # Calculate size (the path may have disappeared since filtering)
        size_display = "[dim]-[/dim]"
        if model.path_exists():
            try:
                files = list(Path(model.path).glob("*.safetensors"))
                total_size = sum(f.stat().st_size for f in files if f.exists())
            except OSError:
                # Only filesystem errors are expected here; a bare except
                # would also swallow Ctrl-C during a slow directory scan.
                pass
            else:
                size_display = format_size(total_size)

        # Format MoE info
        experts = f"[yellow]{model.moe_num_experts}[/yellow]" if model.moe_num_experts else "[dim]-[/dim]"
        active = f"[green]{model.moe_num_experts_per_tok}[/green]" if model.moe_num_experts_per_tok else "[dim]-[/dim]"
        # NOTE(review): the "MoE Size" column shows the same total as
        # "Total" - confirm whether an expert-only size was intended.
        moe_size = f"[cyan]{size_display}[/cyan]" if model.moe_num_experts else "[dim]-[/dim]"

        # Format repo info
        if model.repo_id:
            repo_abbr = "hf" if model.repo_type == "huggingface" else "ms"
            repo_display = f"{repo_abbr}:{model.repo_id}"
        else:
            repo_display = "[dim]-[/dim]"

        # Format SHA256 status (unknown statuses are shown verbatim)
        sha256_display = SHA256_STATUS_MAP.get(model.sha256_status, model.sha256_status)

        table.add_row(
            str(i),
            model.name,
            str(model.path),
            size_display,
            experts,
            active,
            moe_size,
            repo_display,
            sha256_display,
        )

    console.print(table)
    console.print()

    choice = prompt_int_with_retry(
        t("run_int_select_model"),
        default=1,
        min_val=1,
        max_val=len(moe_models),
    )

    return moe_models[choice - 1]
+ """ + from kt_kernel.cli.utils.run_configs import RunConfigManager + + config_manager = RunConfigManager() + saved_configs = config_manager.list_configs(model.id) + + # Debug output (can be removed later) + if False: # Set to True for debugging + console.print() + console.print(f"[dim]DEBUG: Model ID: {model.id}[/dim]") + console.print(f"[dim]DEBUG: Saved configs count: {len(saved_configs)}[/dim]") + if saved_configs: + console.print(f"[dim]DEBUG: Configs: {[c.get('config_name', '?') for c in saved_configs]}[/dim]") + console.print() + + console.print() + console.print(Panel("[bold cyan]Step 2: Select Inference Method[/bold cyan]", expand=False)) + console.print() + + options = [] + option_map = {} + + # Option 1: Use saved configuration (if any) + if saved_configs: + option_idx = len(options) + 1 + console.print(f" [cyan][{option_idx}][/cyan] [bold]Use Saved Configuration[/bold]") + console.print(f" [dim]{len(saved_configs)} saved config(s) available[/dim]") + options.append(str(option_idx)) + option_map[str(option_idx)] = "saved" + + # Option 2: Raw precision inference + option_idx = len(options) + 1 + console.print(f" [cyan][{option_idx}][/cyan] [bold]Raw Precision Inference[/bold]") + console.print(" [dim]FP8 / FP8_PERCHANNEL / BF16 / RAWINT4[/dim]") + options.append(str(option_idx)) + option_map[str(option_idx)] = "raw" + + # Option 3: AMX quantized inference + option_idx = len(options) + 1 + console.print(f" [cyan][{option_idx}][/cyan] [bold]AMX Quantized Inference[/bold]") + console.print(" [dim]INT4 / INT8 (CPU optimized)[/dim]") + options.append(str(option_idx)) + option_map[str(option_idx)] = "amx" + + # Option 4: GGUF inference + option_idx = len(options) + 1 + console.print(f" [cyan][{option_idx}][/cyan] [bold]GGUF Inference[/bold]") + console.print(" [dim]Llamafile format[/dim]") + options.append(str(option_idx)) + option_map[str(option_idx)] = "gguf" + + console.print() + + choice = prompt_choice_with_retry("Select method", choices=options, 
default="1") + method = option_map[choice] + + if method == "saved": + return _select_saved_config(model, saved_configs) + elif method == "raw": + return _configure_raw_inference(model) + elif method == "amx": + return _configure_amx_inference(model) + elif method == "gguf": + return _configure_gguf_inference(model) + + return None + + +def _select_saved_config(model: Any, saved_configs: List[Dict]) -> Optional[Dict[str, Any]]: + """Select from saved configurations with detailed display.""" + console.print() + console.print("[bold]Saved Configurations:[/bold]") + console.print() + + for i, cfg in enumerate(saved_configs, 1): + # Build method display + method_display = cfg.get("inference_method", "unknown").upper() + kt_method = cfg.get("kt_method", "unknown") + + if cfg.get("inference_method") == "raw": + raw_method = cfg.get("raw_method", "unknown") + method_display = f"{raw_method}" + elif cfg.get("inference_method") == "amx": + method_display = kt_method + elif cfg.get("inference_method") == "gguf": + method_display = "LLAMAFILE" + else: + method_display = kt_method + + # Display config header + console.print(f" [cyan][{i}][/cyan] [bold]{cfg.get('config_name', f'Config {i}')}[/bold]") + console.print() + + # Display detailed parameters + console.print(f" [yellow]KT Method:[/yellow] {method_display}") + console.print(f" [yellow]NUMA Nodes:[/yellow] {cfg.get('numa_nodes', '?')}") + console.print(f" [yellow]CPU Threads:[/yellow] {cfg.get('cpu_threads', '?')}") + console.print(f" [yellow]GPU Experts:[/yellow] {cfg.get('gpu_experts', '?')}") + console.print(f" [yellow]TP Size:[/yellow] {cfg.get('tp_size', '?')}") + console.print(f" [yellow]Memory Fraction:[/yellow] {cfg.get('mem_fraction_static', '?')}") + console.print(f" [yellow]Server:[/yellow] {cfg.get('host', '0.0.0.0')}:{cfg.get('port', 30000)}") + + # Display KV cache info if present + if cfg.get("kv_cache"): + console.print(f" [yellow]KV Cache:[/yellow] {cfg.get('kv_cache', '?')}") + console.print(f" 
[yellow]Chunk Prefill:[/yellow] {cfg.get('chunk_prefill', '?')}") + console.print(f" [yellow]GPU Prefill Thr:[/yellow] {cfg.get('gpu_prefill_threshold', '?')}") + + # Display parser info if present + if cfg.get("tool_call_parser") or cfg.get("reasoning_parser"): + if cfg.get("tool_call_parser"): + console.print(f" [yellow]Tool Call Parser:[/yellow] {cfg.get('tool_call_parser')}") + if cfg.get("reasoning_parser"): + console.print(f" [yellow]Reasoning Parser:[/yellow] {cfg.get('reasoning_parser')}") + + console.print() + + # Build and display command preview + cmd_preview = _build_command_preview(model, cfg) + console.print(" [dim]Command:[/dim]") + console.print() + for line in cmd_preview: + console.print(f" {line}") + console.print() + + choice = prompt_int_with_retry( + "Select configuration", + default=1, + min_val=1, + max_val=len(saved_configs), + ) + + selected_config = saved_configs[choice - 1].copy() + selected_config["method"] = "saved" + return selected_config + + +def _build_command_preview(model: Any, cfg: Dict[str, Any]) -> List[str]: + """Build command preview for saved configuration. 
+ + Args: + model: UserModel object + cfg: Saved configuration dict + + Returns: + List of command lines for display + """ + import sys + + host = cfg.get("host", "0.0.0.0") + port = cfg.get("port", 30000) + + lines = [ + "python -m sglang.launch_server \\", + f" --host {host} \\", + f" --port {port} \\", + f" --model {cfg.get('model_path', '?')} \\", + f" --kt-weight-path {cfg.get('weights_path', '?')} \\", + f" --kt-cpuinfer {cfg.get('cpu_threads', '?')} \\", + f" --kt-threadpool-count {cfg.get('numa_nodes', '?')} \\", + f" --kt-num-gpu-experts {cfg.get('gpu_experts', '?')} \\", + f" --kt-method {cfg.get('kt_method', '?')} \\", + ] + + # Add GPU prefill threshold (use saved value or default) + gpu_prefill = cfg.get("gpu_prefill_threshold", 500) + lines.append(f" --kt-gpu-prefill-token-threshold {gpu_prefill} \\") + lines.append(" --kt-enable-dynamic-expert-update \\") + + # Add attention backend + lines.append(" --attention-backend flashinfer \\") + lines.append(" --trust-remote-code \\") + + # Add memory and performance settings + lines.append(f" --mem-fraction-static {cfg.get('mem_fraction_static', 0.9)} \\") + + # Add KV cache settings + chunk_prefill = cfg.get("chunk_prefill", 32768) + max_tokens = cfg.get("kv_cache", 32768) + lines.append(f" --chunked-prefill-size {chunk_prefill} \\") + lines.append(f" --max-total-tokens {max_tokens} \\") + + lines.append(" --max-running-requests 4 \\") + lines.append(" --watchdog-timeout 3000 \\") + lines.append(" --enable-mixed-chunk \\") + + # Add TP size (will be updated with actual GPU selection) + lines.append(f" --tensor-parallel-size {cfg.get('tp_size', '?')} \\") + lines.append(" --enable-p2p-check \\") + + # Add FP8 backend if using FP8 + kt_method = cfg.get("kt_method", "") + if "FP8" in kt_method.upper(): + lines.append(" --fp8-gemm-backend triton \\") + + # Add parsers if configured + if cfg.get("tool_call_parser"): + lines.append(f" --tool-call-parser {cfg['tool_call_parser']} \\") + if 
cfg.get("reasoning_parser"): + lines.append(f" --reasoning-parser {cfg['reasoning_parser']} \\") + + # Remove trailing backslash from last line + if lines: + lines[-1] = lines[-1].rstrip(" \\") + + return lines + + +def _configure_raw_inference(model: Any) -> Dict[str, Any]: + """Configure raw precision inference.""" + console.print() + console.print("[bold]Select Raw Precision Type:[/bold]") + console.print() + console.print(" [cyan][1][/cyan] FP8") + console.print(" [cyan][2][/cyan] FP8_PERCHANNEL") + console.print(" [cyan][3][/cyan] BF16") + console.print(" [cyan][4][/cyan] RAWINT4") + console.print() + + choice = prompt_choice_with_retry("Select precision", choices=["1", "2", "3", "4"], default="1") + + precision_map = { + "1": "FP8", + "2": "FP8_PERCHANNEL", + "3": "BF16", + "4": "RAWINT4", + } + + raw_method = precision_map[choice] + + return { + "method": "raw", + "raw_method": raw_method, + "kt_method": raw_method, + "model_path": model.path, + "weights_path": model.path, # Same as model path for raw + } + + +def _configure_amx_inference(model: Any) -> Optional[Dict[str, Any]]: + """Configure AMX quantized inference.""" + from kt_kernel.cli.utils.user_model_registry import UserModelRegistry + from kt_kernel.cli.commands.model import is_amx_weights + + registry = UserModelRegistry() + all_models = registry.list_models() + + # Filter AMX models + amx_models = [] + for m in all_models: + if m.format == "safetensors": + is_amx, numa = is_amx_weights(m.path) + if is_amx: + # Check if it's derived from the selected model + if m.amx_source_model == model.name: + amx_models.insert(0, m) # Prioritize matched models + else: + amx_models.append(m) + + if not amx_models: + console.print("[yellow]No AMX quantized models found.[/yellow]") + console.print(" Quantize your model with: [cyan]kt quant[/cyan]") + return None + + console.print() + console.print("[bold]Select AMX Weights:[/bold]") + console.print() + + for i, m in enumerate(amx_models, 1): + is_amx, numa = 
def configure_numa_and_cpu(method_config: Dict[str, Any]) -> Dict[str, int]:
    """Step 3: Configure NUMA and CPU threads.

    The suggested defaults are always kept inside the range the prompt
    accepts: the NUMA default is the AMX quantization hint when present
    (capped at the detected node count), otherwise all detected nodes; the
    thread default is 80% of the logical threads, but at least 1.

    Args:
        method_config: Config from step 2 (may contain amx_numa_nodes hint)

    Returns:
        Dict with 'numa_nodes' and 'cpu_threads'
    """
    from kt_kernel.cli.utils.environment import detect_cpu_info

    cpu_info = detect_cpu_info()
    max_numa = cpu_info.numa_nodes
    max_cores = cpu_info.threads  # Use logical threads instead of physical cores

    console.print()
    console.print(Panel("[bold cyan]Step 3: NUMA and CPU Configuration[/bold cyan]", expand=False))
    console.print()

    default_numa = max_numa

    # Show AMX hint if applicable
    if method_config.get("method") == "amx" and method_config.get("amx_numa_nodes"):
        amx_numa = method_config["amx_numa_nodes"]
        console.print(f"[yellow]⚠ Note: This AMX model was quantized with NUMA={amx_numa}[/yellow]")
        console.print("[yellow] For optimal performance, use the same NUMA setting.[/yellow]")
        console.print()
        # The weights may have been quantized on a machine with more NUMA
        # nodes than this one; an out-of-range default would be rejected
        # by the prompt when the user just presses Enter.
        if isinstance(amx_numa, int):
            default_numa = max(1, min(amx_numa, max_numa))

    numa_nodes = prompt_int_with_retry(
        f"NUMA Nodes (1 to {max_numa})",
        default=default_numa,
        min_val=1,
        max_val=max_numa,
    )

    # int(1 * 0.8) == 0, which is below min_val on a single-thread host.
    default_threads = max(1, int(max_cores * 0.8))
    cpu_threads = prompt_int_with_retry(
        f"CPU Threads (1 to {max_cores})",
        default=default_threads,
        min_val=1,
        max_val=max_cores,
    )

    return {
        "numa_nodes": numa_nodes,
        "cpu_threads": cpu_threads,
    }
+ + Args: + model: Selected model + + Returns: + Number of GPU experts + """ + from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model + + console.print() + console.print(Panel("[bold cyan]Step 4: GPU Experts Configuration[/bold cyan]", expand=False)) + console.print() + + # Try to get num_experts from model + try: + moe_result = analyze_moe_model(model.path) + num_experts = moe_result.get("num_experts", 256) + except Exception: + num_experts = 256 # Default fallback + + console.print(f"[dim]Model has {num_experts} experts total[/dim]") + console.print() + console.print("[yellow]⚠ Tip: More GPU experts = faster inference, but uses more VRAM[/yellow]") + console.print() + + default_experts = min(8, num_experts) + gpu_experts = prompt_int_with_retry( + f"GPU Experts per layer (0 to {num_experts})", + default=default_experts, + min_val=0, + max_val=num_experts, + ) + + return gpu_experts + + +def configure_kv_cache(is_raw_inference: bool) -> Optional[Dict[str, int]]: + """Step 5: Configure KV Cache (only for raw inference). 
+ + Args: + is_raw_inference: True if using raw precision inference + + Returns: + Dict with 'kv_cache', 'chunk_prefill', 'gpu_prefill_threshold' or None if not applicable + """ + if not is_raw_inference: + return None + + console.print() + console.print(Panel("[bold cyan]Step 5: KV Cache and Prefill Configuration[/bold cyan]", expand=False)) + console.print() + console.print("[dim]These settings control memory allocation and prefill batch size[/dim]") + console.print("[dim]gpu-prefill-token-threshold: maximum length for single layerwise prefill[/dim]") + console.print() + + kv_cache = prompt_int_with_retry("KV Cache Size (max_total_tokens)", default=32768, min_val=1) + chunk_prefill = prompt_int_with_retry("Chunk Prefill Size", default=32768, min_val=1) + gpu_prefill_threshold = prompt_int_with_retry("GPU Prefill Token Threshold", default=500, min_val=1) + + return { + "kv_cache": kv_cache, + "chunk_prefill": chunk_prefill, + "gpu_prefill_threshold": gpu_prefill_threshold, + } + + +def select_gpus_and_tp( + required_tp_size: Optional[int] = None, saved_mem_fraction: Optional[float] = None +) -> Tuple[List[int], int, float]: + """Step 6: Select GPUs, TP size, and memory fraction. + + Args: + required_tp_size: If specified, user must select exactly this many GPUs. + If None, TP size can be any power of 2. + saved_mem_fraction: If specified, use this memory fraction instead of prompting. + Used when loading saved configurations. 
+ + Returns: + Tuple of (selected_gpu_ids, tp_size, mem_fraction_static) + """ + gpu_info_list = get_gpu_info() + + if not gpu_info_list: + console.print("[red]No GPUs detected[/red]") + return [], 0, 0.9 + + console.print() + if required_tp_size is not None: + console.print(Panel(f"[bold cyan]Select {required_tp_size} GPUs (for saved config)[/bold cyan]", expand=False)) + console.print() + console.print(f"[yellow]Required TP size: {required_tp_size}[/yellow]") + console.print(f"[yellow]You must select exactly {required_tp_size} GPU(s)[/yellow]") + else: + console.print(Panel("[bold cyan]Step 6: GPU Selection and Memory[/bold cyan]", expand=False)) + console.print() + console.print("[dim]TP (Tensor Parallel) size must be a power of 2: 1, 2, 4, 8, ...[/dim]") + console.print() + + # Display GPUs + table = Table(box=box.ROUNDED, show_header=True, header_style="bold cyan") + table.add_column("ID", justify="right", style="cyan") + table.add_column("Name", style="white") + table.add_column("Free VRAM", justify="right", style="green") + table.add_column("Total VRAM", justify="right", style="dim") + + for gpu in gpu_info_list: + table.add_row(str(gpu["id"]), gpu["name"], f"{gpu['free_vram_gb']:.1f} GB", f"{gpu['total_vram_gb']:.1f} GB") + + console.print(table) + console.print() + + # Validator function + def validate_tp_requirements(gpu_ids: List[int]) -> tuple[bool, Optional[str]]: + """Validate TP requirements based on required_tp_size.""" + actual_count = len(gpu_ids) + + if required_tp_size is not None: + # Exact count required + if actual_count != required_tp_size: + return False, f"Must select exactly {required_tp_size} GPU(s), but you selected {actual_count}." + else: + # Must be power of 2 + if actual_count & (actual_count - 1) != 0: + return ( + False, + f"TP size ({actual_count}) must be a power of 2. Valid sizes: 1, 2, 4, 8, 16, 32, ...\nYou selected {actual_count} GPU(s). 
def configure_parsers() -> Dict[str, Optional[str]]:
    """Step 7: Configure parsers (optional).

    Prompts for a tool-call parser and a reasoning parser. An empty or
    whitespace-only answer means "no parser" and is returned as None.

    Returns:
        Dict with 'tool_call_parser' and 'reasoning_parser'; each value is a
        non-empty, stripped string or None.
    """
    console.print()
    console.print(Panel("[bold cyan]Step 7: Parser Configuration (Optional)[/bold cyan]", expand=False))
    console.print()
    console.print("[dim]Press Enter to skip (no parser will be added)[/dim]")
    console.print()

    def _ask_optional(label: str) -> Optional[str]:
        """Prompt once and normalise blank answers to None."""
        answer = Prompt.ask(label, default="")
        # Strip first, then fall back to None, so "   " is treated like an empty answer
        # instead of being returned as an empty string.
        return (answer or "").strip() or None

    tool_call_parser = _ask_optional("Tool Call Parser (e.g., glm47)")
    reasoning_parser = _ask_optional("Reasoning Parser (e.g., glm45)")

    console.print()
    if tool_call_parser or reasoning_parser:
        if tool_call_parser:
            console.print(f"[green]✓[/green] Tool Call Parser: {tool_call_parser}")
        if reasoning_parser:
            console.print(f"[green]✓[/green] Reasoning Parser: {reasoning_parser}")
    else:
        console.print("[dim]No parsers configured[/dim]")

    return {
        "tool_call_parser": tool_call_parser,
        "reasoning_parser": reasoning_parser,
    }
def save_config_prompt(model: Any, full_config: Dict[str, Any]) -> bool:
    """Step 9: Prompt to save configuration.

    Asks whether the finished configuration should be stored for later reuse
    and, if so, under which name. Runtime-only entries (the model object and
    the selected GPU ids) are not written.

    Args:
        model: Selected model; its ``id`` is the key the config is stored under
        full_config: Complete configuration dict

    Returns:
        True if saved, False otherwise
    """
    console.print()
    # The parser step already uses the "Step 7" label; this prompt runs after
    # host/port (Step 8), so it is labelled Step 9 to match the calling flow.
    console.print(Panel("[bold cyan]Step 9: Save Configuration[/bold cyan]", expand=False))
    console.print()

    if not Confirm.ask("Save this configuration for future use?", default=True):
        return False

    config_name = Prompt.ask("Configuration name", default=f"Config {full_config.get('inference_method', 'default')}")

    from kt_kernel.cli.utils.run_configs import RunConfigManager

    config_manager = RunConfigManager()

    # Prepare config to save (exclude runtime-only fields and non-serializable objects)
    save_config = {
        "config_name": config_name,
        "inference_method": full_config["inference_method"],
        "kt_method": full_config["kt_method"],
        # Paths are stored as plain strings
        "model_path": str(full_config["model_path"]),
        "weights_path": str(full_config["weights_path"]),
        "numa_nodes": full_config["numa_nodes"],
        "cpu_threads": full_config["cpu_threads"],
        "gpu_experts": full_config["gpu_experts"],
        "tp_size": full_config["tp_size"],
        "mem_fraction_static": full_config["mem_fraction_static"],
        "host": full_config["host"],
        "port": full_config["port"],
        # Note: selected_gpus is NOT saved - user will select GPUs when loading config
    }

    # Optional entries (parsers, raw method) are only written when they carry a value
    for optional_key in ("tool_call_parser", "reasoning_parser", "raw_method"):
        if full_config.get(optional_key):
            save_config[optional_key] = full_config[optional_key]

    # The three KV-cache/prefill settings are saved together, keyed on kv_cache being set
    if full_config.get("kv_cache"):
        save_config["kv_cache"] = full_config["kv_cache"]
        save_config["chunk_prefill"] = full_config["chunk_prefill"]
        save_config["gpu_prefill_threshold"] = full_config["gpu_prefill_threshold"]

    config_manager.save_config(model.id, save_config)

    console.print()
    console.print(f"[green]✓[/green] Configuration saved: {config_name}")

    return True
console.print() + + # Suggest next available port + found, suggested_port = find_available_port(saved_host, saved_port + 1, max_attempts=100) + if found: + console.print(f"[yellow]Suggestion:[/yellow] Port {suggested_port} is available") + console.print() + + # Ask user for new port + while True: + new_port = prompt_int_with_retry( + "Enter new port", + default=suggested_port if found else saved_port + 1, + min_val=1024, + max_val=65535, + ) + + console.print() + console.print(f"[dim]Checking port {new_port} availability...[/dim]") + + if is_port_available(saved_host, new_port): + console.print(f"[green]✓[/green] Port {new_port} is available") + method_config["port"] = new_port + method_config["host"] = saved_host + break + else: + console.print(f"[red]✗[/red] Port {new_port} is already in use") + console.print() + + # Add model object for run.py compatibility + method_config["model"] = model + + # Ensure paths are Path objects + from pathlib import Path + + if "model_path" in method_config: + method_config["model_path"] = Path(method_config["model_path"]) + if "weights_path" in method_config: + method_config["weights_path"] = Path(method_config["weights_path"]) + + # Display configuration summary + console.print() + console.print(Panel("[bold cyan]Saved Configuration[/bold cyan]", expand=False)) + console.print() + _display_config_summary(method_config) + console.print() + + # Start directly without confirmation when using saved config + return method_config + + # Step 3: Configure NUMA and CPU + numa_cpu_config = configure_numa_and_cpu(method_config) + + # Step 4: Configure GPU experts + gpu_experts = configure_gpu_experts(model) + + # Step 5: Configure KV Cache (only for raw) + is_raw = method_config.get("method") == "raw" + kv_config = configure_kv_cache(is_raw) + + # Step 6: Select GPUs and TP + selected_gpus, tp_size, mem_fraction = select_gpus_and_tp() + if not selected_gpus: + return None + + # Step 7: Configure parsers (optional) + parser_config = 
configure_parsers() + + # Step 8: Configure host and port + server_config = configure_host_and_port() + + # Build complete configuration + full_config = { + "model": model, + "inference_method": method_config["method"], + "kt_method": method_config["kt_method"], + "model_path": method_config["model_path"], + "weights_path": method_config["weights_path"], + **numa_cpu_config, + "gpu_experts": gpu_experts, + "selected_gpus": selected_gpus, + "tp_size": tp_size, + "mem_fraction_static": mem_fraction, + **parser_config, # Add parser config + **server_config, # Add server config (host, port) + } + + # Add raw-specific config + if kv_config: + full_config["raw_method"] = method_config.get("raw_method") + full_config.update(kv_config) + + # Step 9: Save configuration + save_config_prompt(model, full_config) + + # Final confirmation + console.print() + console.print(Panel("[bold cyan]Configuration Complete[/bold cyan]", expand=False)) + console.print() + _display_config_summary(full_config) + console.print() + + if not Confirm.ask("[bold green]Start model server with this configuration?[/bold green]", default=True): + console.print("[yellow]Cancelled[/yellow]") + return None + + return full_config + + +def _display_config_summary(config: Dict[str, Any]): + """Display configuration summary.""" + model = config["model"] + console.print(f" Model: {model.name}") + console.print(f" KT Method: {config['kt_method']}") + console.print(f" NUMA Nodes: {config['numa_nodes']}") + console.print(f" CPU Threads: {config['cpu_threads']}") + console.print(f" GPU Experts: {config['gpu_experts']}") + + # Handle both new config and saved config format + tp_size = config.get("tp_size", len(config.get("selected_gpus", []))) + selected_gpus = config.get("selected_gpus", []) + + console.print(f" GPUs: {selected_gpus} (TP={tp_size})") + console.print(f" Memory Fraction: {config['mem_fraction_static']}") + + # Server config + host = config.get("host", "0.0.0.0") + port = config.get("port", 30000) + 
def get_num_experts(model_path: Path) -> int:
    """
    Get the total number of experts per layer from model config.

    Only keys that hold the TOTAL expert count are consulted.
    ``num_experts_per_tok`` is deliberately not one of them: it is the number
    of experts activated per token (top-k), and such configs also carry a
    separate total-count key. Returning it would shrink the search range used
    by the tuner.

    Args:
        model_path: Path to the model directory

    Returns:
        Number of experts per layer

    Raises:
        ValueError: If config.json is missing or unparsable, is not a JSON
            object, has no known expert-count field, or that field is not an
            integer
    """
    config_file = model_path / "config.json"

    if not config_file.exists():
        raise ValueError(f"config.json not found in {model_path}")

    try:
        config = json.loads(config_file.read_text(encoding="utf-8"))
    except Exception as e:
        # Chain the cause so the original parse/IO error stays in the traceback
        raise ValueError(f"Failed to parse config.json: {e}") from e

    if not isinstance(config, dict):
        raise ValueError(f"Failed to parse config.json: expected a JSON object in {config_file}")

    # Different models may use different field names for the total expert count
    possible_keys = [
        "n_routed_experts",  # e.g. DeepSeek-style configs
        "num_local_experts",  # e.g. Mixtral-style configs
        "num_experts",  # e.g. Qwen-MoE-style configs / generic
    ]

    for key in possible_keys:
        if key in config:
            value = config[key]
            # bool is a subclass of int; reject it explicitly
            if isinstance(value, bool) or not isinstance(value, int):
                raise ValueError(f"Field '{key}' in {config_file} is not an integer: {value!r}")
            return value

    raise ValueError(f"Cannot find num_experts field in {config_file}. " f"Tried: {', '.join(possible_keys)}")
+ + Args: + num_gpu_experts: Number of GPU experts to test + model_path: Path to the model + config: Configuration dict with all parameters + verbose: Whether to show detailed logs + + Returns: + (success: bool, elapsed_time: float) + - success: True if server starts and inference works + - elapsed_time: Time taken for the test + """ + start_time = time.time() + + # Use random port to avoid conflicts + test_port = random.randint(30000, 40000) + + # Build command + cmd = [ + sys.executable, + "-m", + "sglang.launch_server", + "--model", + str(model_path), + "--port", + str(test_port), + "--host", + "127.0.0.1", + "--tensor-parallel-size", + str(config["tensor_parallel_size"]), + "--kt-num-gpu-experts", + str(num_gpu_experts), + "--max-total-tokens", + str(config["max_total_tokens"]), + ] + + # Add kt-kernel options + if config.get("weights_path"): + cmd.extend(["--kt-weight-path", str(config["weights_path"])]) + else: + cmd.extend(["--kt-weight-path", str(model_path)]) + + cmd.extend( + [ + "--kt-cpuinfer", + str(config.get("cpu_threads", 64)), + "--kt-threadpool-count", + str(config.get("numa_nodes", 2)), + "--kt-method", + config.get("kt_method", "AMXINT4"), + "--kt-gpu-prefill-token-threshold", + str(config.get("kt_gpu_prefill_threshold", 4096)), + ] + ) + + # Add other SGLang options + if config.get("attention_backend"): + cmd.extend(["--attention-backend", config["attention_backend"]]) + + cmd.extend( + [ + "--trust-remote-code", + "--mem-fraction-static", + str(config.get("mem_fraction_static", 0.98)), + "--chunked-prefill-size", + str(config.get("chunked_prefill_size", 4096)), + "--max-running-requests", + str(config.get("max_running_requests", 1)), # Use 1 for faster testing + "--watchdog-timeout", + str(config.get("watchdog_timeout", 3000)), + "--enable-mixed-chunk", + "--enable-p2p-check", + ] + ) + + # Add disable-shared-experts-fusion if specified + if config.get("disable_shared_experts_fusion"): + cmd.append("--disable-shared-experts-fusion") + + # Add 
extra args + if config.get("extra_args"): + cmd.extend(config["extra_args"]) + + if verbose: + console.print(f"[dim]Command: {' '.join(cmd)}[/dim]") + + # Start process + try: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + env=config.get("env"), + ) + except Exception as e: + if verbose: + print_error(f"Failed to start process: {e}") + return False, time.time() - start_time + + # Monitor process output + timeout = 60 # Maximum 60 seconds to wait + server_ready = False + + try: + while time.time() - start_time < timeout: + # Check if process has output + if process.poll() is not None: + # Process exited + if verbose: + print_warning("Process exited early") + return False, time.time() - start_time + + # Read output line (non-blocking) + try: + line = process.stdout.readline() + if not line: + time.sleep(0.1) + continue + + if verbose: + console.print(f"[dim]{line.rstrip()}[/dim]") + + # Fast OOM detection + if detect_oom(line): + if verbose: + print_warning(f"OOM detected: {line.rstrip()}") + process.terminate() + try: + process.wait(timeout=2) + except subprocess.TimeoutExpired: + process.kill() + return False, time.time() - start_time + + # Check for startup success + if "Uvicorn running" in line or "Application startup complete" in line: + server_ready = True + break + + except Exception as e: + if verbose: + print_warning(f"Error reading output: {e}") + break + + if not server_ready: + # Timeout or failed to start + process.terminate() + try: + process.wait(timeout=2) + except subprocess.TimeoutExpired: + process.kill() + return False, time.time() - start_time + + # Server is ready, test inference + success = test_inference(test_port, verbose=verbose) + + # Cleanup + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait(timeout=2) + + return success, time.time() - start_time + + except KeyboardInterrupt: + # User cancelled + 
def test_inference(port: int, verbose: bool = False) -> bool:
    """
    Test if the server can handle a simple inference request.

    Sends a one-token chat completion to the OpenAI-compatible endpoint on
    127.0.0.1. Any exception raised along the way is reported (when verbose)
    and treated as a failed test rather than propagated.

    Args:
        port: Server port
        verbose: Whether to show detailed logs

    Returns:
        True if inference succeeds (or the ``openai`` package is not installed,
        in which case the check is skipped), False otherwise
    """
    try:
        # Wait a bit for server to be fully ready
        time.sleep(2)

        # Try to import OpenAI client
        try:
            from openai import OpenAI
        except ImportError:
            if verbose:
                print_warning("OpenAI package not available, skipping inference test")
            return True  # Assume success if we can't test

        client = OpenAI(
            base_url=f"http://127.0.0.1:{port}/v1",
            api_key="test",
        )

        # Send a simple test request
        response = client.chat.completions.create(
            model="test",
            messages=[{"role": "user", "content": "Hi"}],
            max_tokens=1,
            temperature=0,
            timeout=10,
        )

        # Check if we got a valid response. Wrapped in bool() so the function
        # returns a real bool, not the (possibly empty or None) choices object
        # that a bare `and` chain would hand back.
        choices = response.choices
        success = bool(choices) and choices[0].message.content is not None

        if verbose:
            if success:
                print_info(f"Inference test passed: {choices[0].message.content}")
            else:
                print_warning("Inference test failed: no valid response")

        return success

    except Exception as e:
        # Best-effort probe: a failing request means "this configuration does not work"
        if verbose:
            print_warning(f"Inference test failed: {e}")
        return False
+ + Args: + model_path: Path to the model + config: Configuration dict + verbose: Whether to show detailed logs + + Returns: + Maximum number of GPU experts that works + """ + # Get number of experts from model config + try: + num_experts = get_num_experts(model_path) + except ValueError as e: + print_error(str(e)) + raise + + console.print() + console.print(f"Binary search range: [0, {num_experts}]") + console.print() + + left, right = 0, num_experts + result = 0 + iteration = 0 + total_iterations = math.ceil(math.log2(num_experts + 1)) + + while left <= right: + iteration += 1 + mid = (left + right) // 2 + + console.print(f"[{iteration}/{total_iterations}] Testing gpu-experts={mid}... ", end="") + + success, elapsed = test_config(mid, model_path, config, verbose=verbose) + + if success: + console.print(f"[green]✓ OK[/green] ({elapsed:.1f}s)") + result = mid + left = mid + 1 + else: + console.print(f"[red]✗ FAILED[/red] ({elapsed:.1f}s)") + right = mid - 1 + + return result + + +def run_tuna( + model_path: Path, + tensor_parallel_size: int, + max_total_tokens: int, + kt_method: str, + verbose: bool = False, + **kwargs, +) -> int: + """ + Run tuna auto-tuning to find optimal num_gpu_experts. 
+ + Args: + model_path: Path to the model + tensor_parallel_size: Tensor parallel size + max_total_tokens: Maximum total tokens + kt_method: KT quantization method + verbose: Whether to show detailed logs + **kwargs: Additional configuration parameters + + Returns: + Optimal num_gpu_experts value + + Raises: + ValueError: If tuning fails completely + """ + # Prepare configuration + config = { + "tensor_parallel_size": tensor_parallel_size, + "max_total_tokens": max_total_tokens, + "kt_method": kt_method, + **kwargs, + } + + # Run binary search + try: + result = find_max_gpu_experts(model_path, config, verbose=verbose) + except KeyboardInterrupt: + console.print() + print_warning("Tuning cancelled by user") + raise + + console.print() + + # Check if even 0 doesn't work + if result == 0: + console.print("[yellow]Testing if gpu-experts=0 is viable...[/yellow]") + success, _ = test_config(0, model_path, config, verbose=verbose) + + if not success: + # Even 0 doesn't work + console.print() + print_error("Failed to start server even with all experts on CPU (gpu-experts=0)") + console.print() + console.print("[bold]Possible reasons:[/bold]") + console.print(" • Insufficient GPU memory for base model layers") + console.print(" • max-total-tokens is too large for available VRAM") + console.print(" • Tensor parallel configuration issue") + console.print() + console.print("[bold]Suggestions:[/bold]") + console.print(f" • Reduce --max-total-tokens (current: {max_total_tokens})") + console.print(f" • Reduce --tensor-parallel-size (current: {tensor_parallel_size})") + console.print(" • Use more GPUs or GPUs with more VRAM") + console.print(" • Try a smaller model") + console.print() + raise ValueError("Minimum GPU memory requirements not met") + else: + # 0 works but nothing more + console.print() + print_warning("All experts will run on CPU (gpu-experts=0). 
" "Performance will be limited by CPU speed.") + + return result diff --git a/kt-kernel/python/cli/utils/user_model_registry.py b/kt-kernel/python/cli/utils/user_model_registry.py new file mode 100644 index 0000000..ef3514e --- /dev/null +++ b/kt-kernel/python/cli/utils/user_model_registry.py @@ -0,0 +1,302 @@ +""" +User Model Registry + +Manages user-registered models in ~/.ktransformers/user_models.yaml +""" + +from dataclasses import dataclass, asdict, field +from datetime import datetime +from pathlib import Path +from typing import Optional, List, Dict, Any +import yaml + + +# Constants +USER_MODELS_FILE = Path.home() / ".ktransformers" / "user_models.yaml" +REGISTRY_VERSION = "1.0" + + +@dataclass +class UserModel: + """Represents a user-registered model""" + + name: str # User-editable name (default: folder name) + path: str # Absolute path to model directory + format: str # "safetensors" | "gguf" + id: Optional[str] = None # Unique UUID for this model (auto-generated if None) + repo_type: Optional[str] = None # "huggingface" | "modelscope" | None + repo_id: Optional[str] = None # e.g., "deepseek-ai/DeepSeek-V3" + sha256_status: str = "not_checked" # "not_checked" | "checking" | "passed" | "failed" | "no_repo" + gpu_model_ids: Optional[List[str]] = None # For llamafile/AMX: list of GPU model UUIDs to run with + created_at: str = field(default_factory=lambda: datetime.now().isoformat()) + last_verified: Optional[str] = None # ISO format datetime + # MoE information (cached from analyze_moe_model) + is_moe: Optional[bool] = None # True if MoE model, False if non-MoE, None if not analyzed + moe_num_experts: Optional[int] = None # Total number of experts (for MoE models) + moe_num_experts_per_tok: Optional[int] = None # Number of active experts per token (for MoE models) + # AMX quantization metadata (for format == "amx") + amx_source_model: Optional[str] = None # Name of the source MoE model that was quantized + amx_quant_method: Optional[str] = None # "int4" | 
"int8" + amx_numa_nodes: Optional[int] = None # Number of NUMA nodes used for quantization + + def __post_init__(self): + """Ensure ID is set after initialization""" + if self.id is None: + import uuid + + self.id = str(uuid.uuid4()) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for YAML serialization""" + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "UserModel": + """Create from dictionary loaded from YAML""" + return cls(**data) + + def path_exists(self) -> bool: + """Check if model path still exists""" + return Path(self.path).exists() + + +class UserModelRegistry: + """Manages the user model registry""" + + def __init__(self, registry_file: Optional[Path] = None): + """ + Initialize the registry + + Args: + registry_file: Path to the registry YAML file (default: USER_MODELS_FILE) + """ + self.registry_file = registry_file or USER_MODELS_FILE + self.models: List[UserModel] = [] + self.version = REGISTRY_VERSION + + # Ensure directory exists + self.registry_file.parent.mkdir(parents=True, exist_ok=True) + + # Load existing registry + self.load() + + def load(self) -> None: + """Load models from YAML file""" + if not self.registry_file.exists(): + # Initialize empty registry + self.models = [] + self.save() # Create the file + return + + try: + with open(self.registry_file, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + + if not data: + self.models = [] + return + + # Load version + self.version = data.get("version", REGISTRY_VERSION) + + # Load models + models_data = data.get("models", []) + self.models = [UserModel.from_dict(m) for m in models_data] + + # Migrate: ensure all models have UUIDs (for backward compatibility) + needs_save = False + for model in self.models: + if model.id is None: + import uuid + + model.id = str(uuid.uuid4()) + needs_save = True + + if needs_save: + self.save() + + except Exception as e: + raise RuntimeError(f"Failed to load user model registry: {e}") + + def 
save(self) -> None: + """Save models to YAML file""" + data = {"version": self.version, "models": [m.to_dict() for m in self.models]} + + try: + with open(self.registry_file, "w", encoding="utf-8") as f: + yaml.safe_dump(data, f, default_flow_style=False, allow_unicode=True, sort_keys=False) + except Exception as e: + raise RuntimeError(f"Failed to save user model registry: {e}") + + def add_model(self, model: UserModel) -> None: + """ + Add a model to the registry + + Args: + model: UserModel instance to add + + Raises: + ValueError: If a model with the same name already exists + """ + if self.check_name_conflict(model.name): + raise ValueError(f"Model with name '{model.name}' already exists") + + self.models.append(model) + self.save() + + def remove_model(self, name: str) -> bool: + """ + Remove a model from the registry + + Args: + name: Name of the model to remove + + Returns: + True if model was removed, False if not found + """ + original_count = len(self.models) + self.models = [m for m in self.models if m.name != name] + + if len(self.models) < original_count: + self.save() + return True + return False + + def update_model(self, name: str, updates: Dict[str, Any]) -> bool: + """ + Update a model's attributes + + Args: + name: Name of the model to update + updates: Dictionary of attributes to update + + Returns: + True if model was updated, False if not found + """ + model = self.get_model(name) + if not model: + return False + + # Update attributes + for key, value in updates.items(): + if hasattr(model, key): + setattr(model, key, value) + + self.save() + return True + + def get_model(self, name: str) -> Optional[UserModel]: + """ + Get a model by name + + Args: + name: Name of the model + + Returns: + UserModel instance or None if not found + """ + for model in self.models: + if model.name == name: + return model + return None + + def get_model_by_id(self, model_id: str) -> Optional[UserModel]: + """ + Get a model by its unique ID + + Args: + model_id: 
UUID of the model + + Returns: + UserModel instance or None if not found + """ + for model in self.models: + if model.id == model_id: + return model + return None + + def list_models(self) -> List[UserModel]: + """ + List all models + + Returns: + List of all UserModel instances + """ + return self.models.copy() + + def find_by_path(self, path: str) -> Optional[UserModel]: + """ + Find a model by its path + + Args: + path: Model directory path + + Returns: + UserModel instance or None if not found + """ + # Normalize paths for comparison + search_path = str(Path(path).resolve()) + + for model in self.models: + model_path = str(Path(model.path).resolve()) + if model_path == search_path: + return model + return None + + def check_name_conflict(self, name: str, exclude_name: Optional[str] = None) -> bool: + """ + Check if a name conflicts with existing models + + Args: + name: Name to check + exclude_name: Optional name to exclude from check (for rename operations) + + Returns: + True if conflict exists, False otherwise + """ + for model in self.models: + if model.name == name and model.name != exclude_name: + return True + return False + + def refresh_status(self) -> Dict[str, List[str]]: + """ + Check all models and identify missing ones + + Returns: + Dictionary with 'valid' and 'missing' lists of model names + """ + valid = [] + missing = [] + + for model in self.models: + if model.path_exists(): + valid.append(model.name) + else: + missing.append(model.name) + + return {"valid": valid, "missing": missing} + + def get_model_count(self) -> int: + """Get total number of registered models""" + return len(self.models) + + def suggest_name(self, base_name: str) -> str: + """ + Suggest a unique name based on base_name + + Args: + base_name: Base name to derive from + + Returns: + A unique name (may have suffix like -2, -3 etc.) 
+ """ + if not self.check_name_conflict(base_name): + return base_name + + counter = 2 + while True: + candidate = f"{base_name}-{counter}" + if not self.check_name_conflict(candidate): + return candidate + counter += 1