Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2026-03-15 02:47:22 +00:00)
kt-cli enhancement (#1834)
* [feat]: redesign kt run interactive configuration with i18n support
  - Redesign kt run with an 8-step interactive flow (model selection, inference method, NUMA/CPU, GPU experts, KV cache, GPU/TP selection, parsers, host/port)
  - Add configuration save/load system (~/.ktransformers/run_configs.yaml)
  - Add i18n support for kt chat (en/zh translations)
  - Add universal input validators with auto-retry and Chinese comma support
  - Add port availability checker with auto-suggestion
  - Add parser configuration (--tool-call-parser, --reasoning-parser)
  - Remove tuna command and clean up redundant files
  - Fix: variable reference bug in run.py; filter to show only MoE models

* [feat]: unify model selection UI and enable shared experts fusion by default
  - Unify the kt run model selection table with the kt model list display
    * Add Total size, MoE Size, Repo, and SHA256 status columns
    * Use consistent formatting and styling
    * Give users more information for decision-making
  - Enable --disable-shared-experts-fusion by default
    * Change default value from False to True
    * Users can still override with --enable-shared-experts-fusion

* [feat]: improve kt chat with performance metrics and better CJK support
  - Add performance metrics display after each response
    * Total time, TTFT (Time To First Token), TPOT (Time Per Output Token)
    * Accurate input/output token counts using the model tokenizer
    * Fallback to estimation if the tokenizer is unavailable
    * Metrics shown in dim style (not prominent)
  - Fix Chinese character input issues
    * Replace Prompt.ask() with console.input() for better CJK support
    * Fixes backspace deletion showing half-characters
  - Suppress NumPy subnormal warnings
    * Filter "The value of the smallest subnormal" warnings
    * Cleaner CLI output on certain hardware environments

* [fix]: correct TTFT measurement in kt chat
  - Move start_time initialization before the API call
  - Previously start_time was set when receiving the first chunk, causing TTFT ≈ 0 ms
  - Now correctly measures time from request sent to first token received

* [docs]: Add Clawdbot integration guide - an enterprise-grade AI assistant deployment plan for KTransformers
* [docs]: Recommend Kimi K2.5 as the core model, highlighting enterprise-grade inference capability
* [docs]: Add link to the Clawdbot Feishu (Lark) integration tutorial

* [feat]: improve CLI table display, model verification, and chat experience
  - Add sequence number (#) column to all model tables by default
  - Filter kt edit to show only MoE GPU models (exclude AMX)
  - Extend kt model verify to check *.json and *.py files in addition to weights
  - Fix re-verification bug where repaired files caused false failures
  - Suppress tokenizer debug output in kt chat token counting

* [fix]: fix CPU core detection.

---------

Co-authored-by: skqliao <skqliao@gmail.com>
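Before the diff: the TTFT/TPOT fix called out above boils down to starting the clock before the request is sent, not on the first received chunk. A minimal sketch of the measurement (illustrative only, not the committed code; it assumes an OpenAI-compatible client, and the function and variable names are hypothetical):

    import time

    def timed_stream(client, model, messages):
        # Start timing BEFORE the request goes out; starting on the first
        # chunk instead is exactly the bug that made TTFT read as ~0 ms.
        start = time.time()
        first_token_time = None
        chunk_count = 0
        text = ""
        stream = client.chat.completions.create(model=model, messages=messages, stream=True)
        for chunk in stream:
            delta = chunk.choices[0].delta if chunk.choices else None
            if delta and delta.content:
                if first_token_time is None:
                    first_token_time = time.time()  # first token arrives here
                text += delta.content
                chunk_count += 1
        total = time.time() - start
        ttft = (first_token_time - start) if first_token_time else total
        # TPOT approximated from chunk arrivals, as in the diff below
        tpot = (total - ttft) / (chunk_count - 1) if chunk_count > 1 else 0.0
        return text, ttft, tpot, total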
@@ -96,9 +96,9 @@ def chat(
kt chat -t 0.9 --max-tokens 4096 # Adjust generation parameters
"""
if not HAS_OPENAI:
print_error("OpenAI Python SDK is required for chat functionality.")
print_error(t("chat_openai_required"))
console.print()
console.print("Install it with:")
console.print(t("chat_install_hint"))
console.print(" pip install openai")
raise typer.Exit(1)

@@ -114,10 +114,10 @@ def chat(
console.print()
console.print(
Panel.fit(
f"[bold cyan]KTransformers Chat[/bold cyan]\n\n"
f"Server: [yellow]{final_host}:{final_port}[/yellow]\n"
f"Temperature: [cyan]{temperature}[/cyan] | Max tokens: [cyan]{max_tokens}[/cyan]\n\n"
f"[dim]Type '/help' for commands, '/quit' to exit[/dim]",
f"[bold cyan]{t('chat_title')}[/bold cyan]\n\n"
f"{t('chat_server')}: [yellow]{final_host}:{final_port}[/yellow]\n"
f"{t('chat_temperature')}: [cyan]{temperature}[/cyan] | {t('chat_max_tokens')}: [cyan]{max_tokens}[/cyan]\n\n"
f"[dim]{t('chat_help_hint')}[/dim]",
border_style="cyan",
)
)
@@ -152,31 +152,44 @@ def chat(
)

# Test connection
print_info("Connecting to server...")
print_info(t("chat_connecting"))
models = client.models.list()
available_models = [m.id for m in models.data]

if not available_models:
print_error("No models available on server")
print_error(t("chat_no_models"))
raise typer.Exit(1)

# Select model
if model:
if model not in available_models:
print_warning(f"Model '{model}' not found. Available models: {', '.join(available_models)}")
print_warning(t("chat_model_not_found", model=model, available=", ".join(available_models)))
selected_model = available_models[0]
else:
selected_model = model
else:
selected_model = available_models[0]

print_success(f"Connected to model: {selected_model}")
print_success(t("chat_connected", model=selected_model))
console.print()

# Load tokenizer for accurate token counting
tokenizer = None
try:
from transformers import AutoTokenizer

# selected_model is the model path
tokenizer = AutoTokenizer.from_pretrained(selected_model, trust_remote_code=True)
console.print(f"[dim]Loaded tokenizer from {selected_model}[/dim]")
console.print()
except Exception as e:
console.print(f"[dim yellow]Warning: Could not load tokenizer, token counts will be estimated[/dim]")
console.print()

except Exception as e:
print_error(f"Failed to connect to server: {e}")
print_error(t("chat_connect_failed", error=str(e)))
console.print()
console.print("Make sure the model server is running:")
console.print(t("chat_server_not_running"))
console.print(" kt run <model>")
raise typer.Exit(1)

@@ -201,12 +214,12 @@ def chat(
# Main chat loop
try:
while True:
# Get user input
# Get user input - use console.input() for better CJK character support
try:
user_input = Prompt.ask("[bold green]You[/bold green]")
user_input = console.input(f"[bold green]{t('chat_user_prompt')}[/bold green]: ")
except (EOFError, KeyboardInterrupt):
console.print()
print_info("Goodbye!")
print_info(t("chat_goodbye"))
break

if not user_input.strip():
@@ -224,15 +237,19 @@ def chat(

# Generate response
console.print()
console.print("[bold cyan]Assistant[/bold cyan]")
console.print(f"[bold cyan]{t('chat_assistant_prompt')}[/bold cyan]")

try:
if stream:
# Streaming response
response_content = _stream_response(client, selected_model, messages, temperature, max_tokens)
response_content = _stream_response(
client, selected_model, messages, temperature, max_tokens, tokenizer
)
else:
# Non-streaming response
response_content = _generate_response(client, selected_model, messages, temperature, max_tokens)
response_content = _generate_response(
client, selected_model, messages, temperature, max_tokens, tokenizer
)

# Add assistant response to history
messages.append({"role": "assistant", "content": response_content})
@@ -240,7 +257,7 @@ def chat(
console.print()

except Exception as e:
print_error(f"Error generating response: {e}")
print_error(t("chat_generation_error", error=str(e)))
# Remove the user message that caused the error
messages.pop()
continue
@@ -252,12 +269,12 @@ def chat(
except KeyboardInterrupt:
console.print()
console.print()
print_info("Chat interrupted. Goodbye!")
print_info(t("chat_interrupted"))

# Final history save
if save_history and messages:
_save_history(history_file, messages, selected_model)
console.print(f"[dim]History saved to: {history_file}[/dim]")
console.print(f"[dim]{t('chat_history_saved', path=str(history_file))}[/dim]")
console.print()


@@ -267,12 +284,22 @@ def _stream_response(
messages: list,
temperature: float,
max_tokens: int,
tokenizer=None,
) -> str:
"""Generate streaming response and display in real-time."""
import time

response_content = ""
reasoning_content = ""

# Performance tracking
first_token_time = None
chunk_count = 0

try:
# Start timing before sending request
start_time = time.time()

stream = client.chat.completions.create(
model=model,
messages=messages,
@@ -282,33 +309,120 @@ def _stream_response(
)

for chunk in stream:
delta = chunk.choices[0].delta
reasoning_delta = getattr(delta, "reasoning_content", None)
if reasoning_delta:
reasoning_content += reasoning_delta
console.print(reasoning_delta, end="", style="dim")
if delta.content:
content = delta.content
response_content += content
console.print(content, end="")
delta = chunk.choices[0].delta if chunk.choices else None
if delta:
reasoning_delta = getattr(delta, "reasoning_content", None)
if reasoning_delta:
if first_token_time is None:
first_token_time = time.time()
reasoning_content += reasoning_delta
console.print(reasoning_delta, end="", style="dim")
chunk_count += 1

if delta.content:
if first_token_time is None:
first_token_time = time.time()
content = delta.content
response_content += content
console.print(content, end="")
chunk_count += 1

console.print() # Newline after streaming

# Display performance metrics
end_time = time.time()
if first_token_time and chunk_count > 0:
ttft = first_token_time - start_time
total_time = end_time - start_time

# Calculate TPOT based on chunks
if chunk_count > 1:
generation_time = total_time - ttft
tpot = generation_time / (chunk_count - 1)
else:
tpot = 0

# Calculate accurate token counts using tokenizer
if tokenizer:
input_tokens = _count_tokens_with_tokenizer(messages, tokenizer)
output_tokens = _count_tokens_with_tokenizer(
[{"role": "assistant", "content": response_content}], tokenizer
)
token_prefix = ""
else:
# Fallback to estimation
input_tokens = _estimate_tokens(messages)
output_tokens = _estimate_tokens([{"role": "assistant", "content": response_content}])
token_prefix = "~"

# Build metrics display
metrics = f"[dim]Total: {total_time*1000:.0f}ms | TTFT: {ttft*1000:.0f}ms"
if tpot > 0:
metrics += f" | TPOT: {tpot*1000:.1f}ms"
metrics += f" | In: {token_prefix}{input_tokens} | Out: {token_prefix}{output_tokens}"
metrics += "[/dim]"

console.print(metrics)

except Exception as e:
raise Exception(f"Streaming error: {e}")

return response_content


def _count_tokens_with_tokenizer(messages: list, tokenizer) -> int:
"""Count tokens accurately using the model's tokenizer."""
try:
# Concatenate all message content
text = ""
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
# Simple format: role + content
text += f"{role}: {content}\n"

# Encode and count tokens - suppress any debug output from custom tokenizers
import os
import sys
from contextlib import redirect_stdout, redirect_stderr

with open(os.devnull, "w") as devnull:
with redirect_stdout(devnull), redirect_stderr(devnull):
tokens = tokenizer.encode(text, add_special_tokens=True)
return len(tokens)
except Exception:
# Fallback to estimation if tokenizer fails
return _estimate_tokens(messages)


def _estimate_tokens(messages: list) -> int:
"""Estimate token count for messages (rough approximation)."""
total_chars = 0
for msg in messages:
content = msg.get("content", "")
total_chars += len(content)

# Rough estimation:
# - English: ~4 chars per token
# - Chinese: ~1.5 chars per token
# Use 2.5 as average
return max(1, int(total_chars / 2.5))
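
A quick sanity check on the heuristic above: 100 English characters estimate to int(100 / 2.5) = 40 tokens (a real tokenizer at ~4 chars/token would give about 25), and 100 Chinese characters also estimate to 40 (~67 actual at ~1.5 chars/token); the single 2.5 divisor splits the difference rather than detecting the script. Assuming the function exactly as defined above:

    assert _estimate_tokens([{"role": "user", "content": "a" * 100}]) == 40
    assert _estimate_tokens([{"role": "user", "content": ""}]) == 1  # never reports zero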


def _generate_response(
client: "OpenAI",
model: str,
messages: list,
temperature: float,
max_tokens: int,
tokenizer=None,
) -> str:
"""Generate non-streaming response."""
import time

try:
start_time = time.time()

response = client.chat.completions.create(
model=model,
messages=messages,
@@ -317,12 +431,36 @@ def _generate_response(
stream=False,
)

end_time = time.time()
total_time = end_time - start_time

content = response.choices[0].message.content

# Display as markdown
md = Markdown(content)
console.print(md)

# Calculate accurate token counts using tokenizer
if tokenizer:
input_tokens = _count_tokens_with_tokenizer(messages, tokenizer)
output_tokens = _count_tokens_with_tokenizer([{"role": "assistant", "content": content}], tokenizer)
token_prefix = ""
else:
# Fallback to API usage or estimation
input_tokens = response.usage.prompt_tokens if response.usage else _estimate_tokens(messages)
output_tokens = (
response.usage.completion_tokens
if response.usage
else _estimate_tokens([{"role": "assistant", "content": content}])
)
token_prefix = "" if response.usage else "~"

# Display performance metrics
console.print(
f"[dim]Time: {total_time*1000:.0f}ms | "
f"In: {token_prefix}{input_tokens} | Out: {token_prefix}{output_tokens}[/dim]"
)

return content

except Exception as e:
@@ -335,20 +473,14 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens

if cmd in ["/quit", "/exit", "/q"]:
console.print()
print_info("Goodbye!")
print_info(t("chat_goodbye"))
return False

elif cmd in ["/help", "/h"]:
console.print()
console.print(
Panel(
"[bold]Available Commands:[/bold]\n\n"
"/help, /h - Show this help message\n"
"/quit, /exit, /q - Exit chat\n"
"/clear, /c - Clear conversation history\n"
"/history, /hist - Show conversation history\n"
"/info, /i - Show current settings\n"
"/retry, /r - Regenerate last response",
f"[bold]{t('chat_help_title')}[/bold]\n\n{t('chat_help_content')}",
title="Help",
border_style="cyan",
)
@@ -359,19 +491,19 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens
elif cmd in ["/clear", "/c"]:
messages.clear()
console.print()
print_success("Conversation history cleared")
print_success(t("chat_history_cleared"))
console.print()
return True

elif cmd in ["/history", "/hist"]:
console.print()
if not messages:
print_info("No conversation history")
print_info(t("chat_no_history"))
else:
console.print(
Panel(
_format_history(messages),
title=f"History ({len(messages)} messages)",
title=t("chat_history_title", count=len(messages)),
border_style="cyan",
)
)
@@ -382,10 +514,7 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens
console.print()
console.print(
Panel(
f"[bold]Current Settings:[/bold]\n\n"
f"Temperature: [cyan]{temperature}[/cyan]\n"
f"Max tokens: [cyan]{max_tokens}[/cyan]\n"
f"Messages: [cyan]{len(messages)}[/cyan]",
f"[bold]{t('chat_info_title')}[/bold]\n\n{t('chat_info_content', temperature=temperature, max_tokens=max_tokens, messages=len(messages))}",
title="Info",
border_style="cyan",
)
@@ -397,16 +526,16 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens
if len(messages) >= 2 and messages[-1]["role"] == "assistant":
# Remove last assistant response
messages.pop()
print_info("Retrying last response...")
print_info(t("chat_retrying"))
console.print()
else:
print_warning("No previous response to retry")
print_warning(t("chat_no_retry"))
console.print()
return True

else:
print_warning(f"Unknown command: {command}")
console.print("[dim]Type /help for available commands[/dim]")
print_warning(t("chat_unknown_command", command=command))
console.print(f"[dim]{t('chat_unknown_hint')}[/dim]")
console.print()
return True

File diff suppressed because it is too large
@@ -35,12 +35,12 @@ class QuantMethod(str, Enum):


def quant(
model: str = typer.Argument(
...,
model: Optional[str] = typer.Argument(
None,
help="Model name or path to quantize",
),
method: QuantMethod = typer.Option(
QuantMethod.INT4,
method: Optional[QuantMethod] = typer.Option(
None,
"--method",
"-m",
help="Quantization method",
@@ -51,8 +51,8 @@ def quant(
"-o",
help="Output path for quantized weights",
),
input_type: str = typer.Option(
"fp8",
input_type: Optional[str] = typer.Option(
None,
"--input-type",
"-i",
help="Input weight type (fp8, fp16, bf16)",
@@ -72,6 +72,11 @@ def quant(
"--no-merge",
help="Don't merge safetensor files",
),
gpu: bool = typer.Option(
False,
"--gpu",
help="Use GPU for conversion (faster)",
),
yes: bool = typer.Option(
False,
"--yes",
@@ -79,54 +84,231 @@ def quant(
help="Skip confirmation prompts",
),
) -> None:
"""Quantize model weights for CPU inference."""
settings = get_settings()
console.print()
"""Quantize model weights for CPU inference.

# Resolve input path
input_path = _resolve_input_path(model, settings)
if input_path is None:
print_error(t("quant_input_not_found", path=model))
If no model is specified, interactive mode will be activated.
"""
settings = get_settings()

# Check if we should use interactive mode
# Interactive mode triggers when: no model, or missing critical parameters
needs_interactive = model is None or method is None or cpu_threads is None or numa_nodes is None
is_interactive = False

if needs_interactive and sys.stdin.isatty():
# Use interactive configuration (includes verification in Step 1.5)
from kt_kernel.cli.utils.quant_interactive import interactive_quant_config

console.print()
console.print(f"[bold cyan]═══ {t('quant_interactive_title')} ═══[/bold cyan]")
console.print()
console.print(f"[yellow]{t('quant_new_model_notice')}[/yellow]")
console.print()

config = interactive_quant_config()
if config is None:
# User cancelled
raise typer.Exit(0)

# Extract configuration
model_obj = config["model"]
model = model_obj.id
input_path = Path(model_obj.path)
method = QuantMethod(config["method"])
input_type = config["input_type"]
cpu_threads = config["cpu_threads"]
numa_nodes = config["numa_nodes"]
output = config["output_path"]
gpu = config["use_gpu"]
is_interactive = True

console.print()
print_success(t("quant_config_complete"))
console.print()
else:
# Non-interactive mode - require model parameter
if model is None:
print_error("Model argument is required in non-interactive mode")
console.print()
console.print("Usage: kt quant <model>")
console.print(" Or: kt quant (for interactive mode)")
raise typer.Exit(1)

# Set defaults for optional parameters
method = method or QuantMethod.INT4
input_type = input_type or "fp8"

console.print()

# Resolve input path
input_path = _resolve_input_path(model, settings)
if input_path is None:
print_error(t("quant_input_not_found", path=model))
raise typer.Exit(1)

# Pre-quantization verification (only in non-interactive mode)
# Interactive mode already did verification in interactive_quant_config()
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
from kt_kernel.cli.utils.model_verifier import pre_operation_verification

user_registry = UserModelRegistry()
user_model_obj = user_registry.find_by_path(str(input_path))

if user_model_obj and user_model_obj.format == "safetensors":
pre_operation_verification(user_model_obj, user_registry, operation_name="quantizing")

# Get user model info for both modes (needed later for registering quantized model)
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry

user_registry = UserModelRegistry()
user_model_obj = user_registry.find_by_path(str(input_path))

# Validate that it's a MoE model (not AMX or GGUF)
from kt_kernel.cli.commands.model import is_amx_weights

# Check if it's AMX (already quantized)
is_amx, _ = is_amx_weights(str(input_path))
if is_amx:
print_error("Cannot quantize AMX models (already quantized)")
console.print()
console.print(f" The model at {input_path} is already in AMX format.")
raise typer.Exit(1)

print_info(t("quant_input_path", path=str(input_path)))
# Check if it's a MoE model
from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model

# Resolve output path
if output is None:
output = input_path.parent / f"{input_path.name}-{method.value.upper()}"

print_info(t("quant_output_path", path=str(output)))
print_info(t("quant_method", method=method.value.upper()))

# Detect CPU configuration
cpu = detect_cpu_info()
final_cpu_threads = cpu_threads or cpu.cores
final_numa_nodes = numa_nodes or cpu.numa_nodes

print_info(f"CPU threads: {final_cpu_threads}")
print_info(f"NUMA nodes: {final_numa_nodes}")

# Check if output exists
if output.exists():
print_warning(f"Output path already exists: {output}")
moe_result = None # Store for later use when registering quantized model
try:
moe_result = analyze_moe_model(str(input_path), use_cache=True)
if not moe_result or not moe_result.get("is_moe"):
print_error("Only MoE models can be quantized to AMX format")
console.print()
console.print(f" The model at {input_path} is not a MoE model.")
console.print(" AMX quantization is designed for MoE models (e.g., DeepSeek-V3).")
raise typer.Exit(1)
except Exception as e:
print_warning(f"Could not detect MoE information: {e}")
console.print()
if not yes:
if not confirm("Overwrite?", default=False):
if not confirm("Continue quantization anyway?", default=False):
raise typer.Exit(1)

# Detect CPU configuration and resolve output path (only needed in non-interactive mode)
if not is_interactive:
print_info(t("quant_input_path", path=str(input_path)))

# Detect CPU configuration (needed for output path)
cpu = detect_cpu_info()
final_cpu_threads = cpu_threads or cpu.cores
final_numa_nodes = numa_nodes or cpu.numa_nodes

# Resolve output path
if output is None:
# Priority: paths.weights > paths.models[0] > model's parent directory
weights_dir = settings.weights_dir

if weights_dir and weights_dir.exists():
# Use configured weights directory (highest priority)
output = weights_dir / f"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}"
else:
# Use first model storage path
model_paths = settings.get_model_paths()
if model_paths and model_paths[0].exists():
output = model_paths[0] / f"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}"
else:
# Fallback to model's parent directory
output = input_path.parent / f"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}"

print_info(t("quant_output_path", path=str(output)))
print_info(t("quant_method", method=method.value.upper()))
print_info(t("quant_cpu_threads", threads=final_cpu_threads))
print_info(t("quant_numa_nodes", nodes=final_numa_nodes))

# Calculate space requirements
console.print()
console.print(f"[bold cyan]{t('quant_disk_analysis')}[/bold cyan]")
console.print()

# Calculate source model size
try:
total_bytes = sum(f.stat().st_size for f in input_path.glob("*.safetensors") if f.is_file())
source_size_gb = total_bytes / (1024**3)
except Exception:
source_size_gb = 0.0

# Estimate quantized size
input_bits = {"fp8": 8, "fp16": 16, "bf16": 16}
quant_bits = {"int4": 4, "int8": 8}
input_bit = input_bits.get(input_type, 16)
quant_bit = quant_bits.get(method.value, 4)
ratio = quant_bit / input_bit
estimated_size_gb = source_size_gb * ratio

# Check available space
import shutil

try:
check_path = output.parent if not output.exists() else output
while not check_path.exists() and check_path != check_path.parent:
check_path = check_path.parent
stat = shutil.disk_usage(check_path)
available_gb = stat.free / (1024**3)
except Exception:
available_gb = 0.0

is_sufficient = available_gb >= (estimated_size_gb * 1.2)

console.print(f" {t('quant_source_size'):<26} {source_size_gb:.2f} GB")
console.print(f" {t('quant_estimated_size'):<26} {estimated_size_gb:.2f} GB")
console.print(f" {t('quant_available_space'):<26} {available_gb:.2f} GB")
console.print()

if not is_sufficient:
required_with_buffer = estimated_size_gb * 1.2
print_warning(t("quant_insufficient_space"))
console.print()
console.print(f" {t('quant_required_space'):<26} {required_with_buffer:.2f} GB")
console.print(f" {t('quant_available_space'):<26} {available_gb:.2f} GB")
console.print(f" {t('quant_shortage'):<26} {required_with_buffer - available_gb:.2f} GB")
console.print()
console.print(f" {t('quant_may_fail')}")
console.print()

if not yes:
if not confirm(t("quant_continue_anyway"), default=False):
raise typer.Abort()
console.print()
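
To make the size arithmetic above concrete (hypothetical numbers, not from the diff): quantizing a 640 GB fp8 source to int4 gives ratio = 4 / 8 = 0.5, so estimated_size_gb = 320 GB, and the 1.2 safety buffer means the check wants at least 384 GB free before proceeding without a warning. An fp16 source to int8 also yields ratio 0.5, while fp16 to int4 gives 0.25 (160 GB for the same source).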

# Check if output exists and generate unique name
if output.exists():
print_warning(t("quant_output_exists", path=str(output)))
console.print()

# Generate unique name by adding suffix
original_name = output.name
parent_dir = output.parent
counter = 2

while output.exists():
new_name = f"{original_name}-{counter}"
output = parent_dir / new_name
counter += 1

print_success(t("quant_using_unique", path=str(output)))
console.print()

# Confirm (only show if not using --yes flag)
if not yes:
console.print()
print_warning(t("quant_time_warning"))
console.print()

if not confirm(t("prompt_continue")):
raise typer.Abort()

# Confirm
if not yes:
console.print()
console.print("[bold]Quantization Settings:[/bold]")
console.print(f" Input: {input_path}")
console.print(f" Output: {output}")
console.print(f" Method: {method.value.upper()}")
console.print(f" Input type: {input_type}")
console.print()
print_warning("Quantization may take 30-60 minutes depending on model size.")
console.print()

if not confirm(t("prompt_continue")):
raise typer.Abort()
else:
# Interactive mode: cpu_threads and numa_nodes already set
final_cpu_threads = cpu_threads
final_numa_nodes = numa_nodes

# Find conversion script
kt_kernel_path = _find_kt_kernel_path()
@@ -141,37 +323,145 @@ def quant(

# Build command
cmd = [
sys.executable, str(script_path),
"--input-path", str(input_path),
"--input-type", input_type,
"--output", str(output),
"--quant-method", method.value,
"--cpuinfer-threads", str(final_cpu_threads),
"--threadpool-count", str(final_numa_nodes),
sys.executable,
str(script_path),
"--input-path",
str(input_path),
"--input-type",
input_type,
"--output",
str(output),
"--quant-method",
method.value,
"--cpuinfer-threads",
str(final_cpu_threads),
"--threadpool-count",
str(final_numa_nodes),
]

if no_merge:
cmd.append("--no-merge-safetensor")

if gpu:
cmd.append("--gpu")

# Run quantization
console.print()
print_step(t("quant_starting"))
console.print()
console.print(f"[dim]$ {' '.join(cmd)}[/dim]")
console.print()
console.print("[dim]" + "=" * 80 + "[/dim]")
console.print()

try:
process = subprocess.run(cmd)
# Run with real-time stdout/stderr output
import os
import time

env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1" # Disable Python output buffering

# Record start time
start_time = time.time()

process = subprocess.run(
cmd,
stdout=None, # Inherit parent's stdout (real-time output)
stderr=None, # Inherit parent's stderr (real-time output)
env=env,
)

# Calculate elapsed time
elapsed_time = time.time() - start_time
hours = int(elapsed_time // 3600)
minutes = int((elapsed_time % 3600) // 60)
seconds = int(elapsed_time % 60)

console.print()
console.print("[dim]" + "=" * 80 + "[/dim]")
console.print()

if process.returncode == 0:
console.print()
print_success(t("quant_complete"))
console.print()

# Display elapsed time
if hours > 0:
time_str = f"{hours}h {minutes}m {seconds}s"
elif minutes > 0:
time_str = f"{minutes}m {seconds}s"
else:
time_str = f"{seconds}s"
console.print(f" [cyan]{t('quant_time_elapsed')} {time_str}[/cyan]")
console.print()
console.print(f" Quantized weights saved to: {output}")
console.print()
console.print(" Use with:")
console.print(f" kt run {model} --weights-path {output}")
console.print()

# Auto-register the quantized model
try:
from kt_kernel.cli.utils.user_model_registry import UserModel

# Generate model name from output path
base_name = output.name
suggested_name = user_registry.suggest_name(base_name)

# Determine MoE information and source model name
if user_model_obj:
is_moe_val = user_model_obj.is_moe
num_experts = user_model_obj.moe_num_experts
num_active = user_model_obj.moe_num_experts_per_tok
repo_type_val = user_model_obj.repo_type
repo_id_val = user_model_obj.repo_id
source_model_name = user_model_obj.name # Store source model name
elif moe_result:
is_moe_val = moe_result.get("is_moe", True)
num_experts = moe_result.get("num_experts")
num_active = moe_result.get("num_experts_per_tok")
repo_type_val = None
repo_id_val = None
source_model_name = input_path.name # Use folder name as fallback
else:
is_moe_val = None
num_experts = None
num_active = None
repo_type_val = None
repo_id_val = None
source_model_name = input_path.name # Use folder name as fallback

# Create new model entry (AMX format uses "safetensors" format, detected by is_amx_weights())
new_model = UserModel(
name=suggested_name,
path=str(output),
format="safetensors", # AMX files are safetensors format
repo_type=repo_type_val,
repo_id=repo_id_val,
sha256_status="not_checked", # AMX weights don't need verification
# Inherit MoE information from source model
is_moe=is_moe_val,
moe_num_experts=num_experts,
moe_num_experts_per_tok=num_active,
# AMX quantization metadata
amx_source_model=source_model_name,
amx_quant_method=method.value, # "int4" or "int8"
amx_numa_nodes=final_numa_nodes,
)

user_registry.add_model(new_model)
console.print()
print_success(t("quant_registered", name=suggested_name))
console.print()
console.print(f" {t('quant_view_with')} [cyan]kt model list[/cyan]")
console.print(f" {t('quant_use_with')} [cyan]kt run {suggested_name}[/cyan]")
console.print()
except Exception as e:
# Non-fatal error - quantization succeeded but registration failed
console.print()
print_warning(t("quant_register_failed", error=str(e)))
console.print()
console.print(f" {t('quant_use_with')}")
console.print(f" kt run {model} --weights-path {output}")
console.print()
else:
print_error(f"Quantization failed with exit code {process.returncode}")
raise typer.Exit(process.returncode)
@@ -221,6 +511,7 @@ def _find_kt_kernel_path() -> Optional[Path]:
"""Find the kt-kernel installation path."""
try:
import kt_kernel

return Path(kt_kernel.__file__).parent.parent
except ImportError:
pass

@@ -28,7 +28,7 @@ from kt_kernel.cli.utils.console import (
prompt_choice,
)
from kt_kernel.cli.utils.environment import detect_cpu_info, detect_gpus, detect_ram_gb
from kt_kernel.cli.utils.model_registry import MODEL_COMPUTE_FUNCTIONS, ModelInfo, get_registry
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry


@click.command(
@@ -120,8 +120,6 @@ def run(
# Handle disable/enable shared experts fusion flags
if enable_shared_experts_fusion:
disable_shared_experts_fusion = False
elif disable_shared_experts_fusion is None:
disable_shared_experts_fusion = None

# Convert Path objects from click
model_path_obj = Path(model_path) if model_path else None
@@ -214,266 +212,250 @@ def _run_impl(
raise typer.Exit(1)

settings = get_settings()
registry = get_registry()
user_registry = UserModelRegistry()

console.print()
# Check if we should use interactive mode
# Interactive mode triggers when:
# 1. No model specified, OR
# 2. Model specified but missing critical parameters (gpu_experts, tensor_parallel_size, etc.)
use_interactive = False

# If no model specified, show interactive selection
if model is None:
model = _interactive_model_selection(registry, settings)
if model is None:
use_interactive = True
elif (
gpu_experts is None
or tensor_parallel_size is None
or cpu_threads is None
or numa_nodes is None
or max_total_tokens is None
):
# Model specified but some parameters missing - use interactive
use_interactive = True

if use_interactive and sys.stdin.isatty():
# Use new interactive configuration flow
from kt_kernel.cli.utils.run_interactive import interactive_run_config

console.print()
console.print("[bold cyan]═══ Interactive Run Configuration ═══[/bold cyan]")
console.print()

config = interactive_run_config()
if config is None:
# User cancelled
raise typer.Exit(0)

# Step 1: Detect hardware
print_step(t("run_detecting_hardware"))
gpus = detect_gpus()
cpu = detect_cpu_info()
ram = detect_ram_gb()
# Extract configuration from new format
user_model_obj = config["model"]
model = user_model_obj.id
resolved_model_path = Path(config["model_path"])
resolved_weights_path = Path(config["weights_path"])

if gpus:
gpu_info = f"{gpus[0].name} ({gpus[0].vram_gb}GB VRAM)"
if len(gpus) > 1:
gpu_info += f" + {len(gpus) - 1} more"
print_info(t("run_gpu_info", name=gpus[0].name, vram=gpus[0].vram_gb))
# Extract parameters
gpu_experts = config["gpu_experts"]
cpu_threads = config["cpu_threads"]
numa_nodes = config["numa_nodes"]
tensor_parallel_size = config["tp_size"]

# Get kt-method and other method-specific settings
kt_method = config["kt_method"]

# KV cache settings (may be None for non-raw methods)
max_total_tokens = config.get("kv_cache", 32768)
chunked_prefill_size = config.get("chunk_prefill", 32768)
kt_gpu_prefill_threshold = config.get("gpu_prefill_threshold", 500)

# Memory settings
mem_fraction_static = config["mem_fraction_static"]

# Parser settings (optional)
tool_call_parser = config.get("tool_call_parser")
reasoning_parser = config.get("reasoning_parser")

# Server settings
host = config.get("host", "0.0.0.0")
port = config.get("port", 30000)

# Set CUDA_VISIBLE_DEVICES for selected GPUs
selected_gpus = config["selected_gpus"]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu_id) for gpu_id in selected_gpus)

# Detect hardware for parameter resolution (needed for resolve() function later)
gpus = detect_gpus()
cpu = detect_cpu_info()

console.print()
print_info(f"[green]✓[/green] Configuration complete")
console.print()
else:
print_warning(t("doctor_gpu_not_found"))
gpu_info = "None"
# Non-interactive mode - use traditional flow
console.print()

print_info(t("run_cpu_info", name=cpu.name, cores=cpu.cores, numa=cpu.numa_nodes))
print_info(t("run_ram_info", total=int(ram)))
# Initialize variables that may have been set by interactive mode
# These will be None in non-interactive mode and will use defaults via resolve()

# Step 2: Resolve model
console.print()
print_step(t("run_checking_model"))
# If no model specified, show old interactive selection
if model is None:
model = _interactive_model_selection(user_registry, settings)
if model is None:
raise typer.Exit(0)

model_info = None
resolved_model_path = model_path
# Detect hardware (needed for defaults)
gpus = detect_gpus()
cpu = detect_cpu_info()
ram = detect_ram_gb()

# Check if model is a path
if Path(model).exists():
resolved_model_path = Path(model)
print_info(t("run_model_path", path=str(resolved_model_path)))

# Try to infer model type from path to use default configurations
# Check directory name against known models
dir_name = resolved_model_path.name.lower()
for registered_model in registry.list_all():
# Check if directory name matches model name or aliases
if dir_name == registered_model.name.lower():
model_info = registered_model
print_info(f"Detected model type: {registered_model.name}")
break
for alias in registered_model.aliases:
if dir_name == alias.lower() or alias.lower() in dir_name:
model_info = registered_model
print_info(f"Detected model type: {registered_model.name}")
break
if model_info:
break

# Also check HuggingFace repo format (org--model)
if not model_info:
for registered_model in registry.list_all():
repo_slug = registered_model.hf_repo.replace("/", "--").lower()
if repo_slug in dir_name or dir_name in repo_slug:
model_info = registered_model
print_info(f"Detected model type: {registered_model.name}")
break

if not model_info:
print_warning("Could not detect model type from path. Using default parameters.")
console.print(" [dim]Tip: Use model name (e.g., 'kt run m2') to apply optimized configurations[/dim]")
else:
# Search in registry
matches = registry.search(model)

if not matches:
print_error(t("run_model_not_found", name=model))
console.print()
console.print("Available models:")
for m in registry.list_all()[:5]:
console.print(f" - {m.name} ({', '.join(m.aliases[:2])})")
raise typer.Exit(1)

if len(matches) == 1:
model_info = matches[0]
if gpus:
gpu_info = f"{gpus[0].name} ({gpus[0].vram_gb}GB VRAM)"
if len(gpus) > 1:
gpu_info += f" + {len(gpus) - 1} more"
print_info(t("run_gpu_info", name=gpus[0].name, vram=gpus[0].vram_gb))
else:
# Multiple matches - prompt user
console.print()
print_info(t("run_multiple_matches"))
choices = [f"{m.name} ({m.hf_repo})" for m in matches]
selected = prompt_choice(t("run_select_model"), choices)
idx = choices.index(selected)
model_info = matches[idx]
print_warning(t("doctor_gpu_not_found"))
gpu_info = "None"

# Find model path
if model_path is None:
resolved_model_path = _find_model_path(model_info, settings)
if resolved_model_path is None:
print_error(t("run_model_not_found", name=model_info.name))
print_info(t("run_cpu_info", name=cpu.name, cores=cpu.cores, numa=cpu.numa_nodes))
print_info(t("run_ram_info", total=int(ram)))

# Step 2: Resolve model
console.print()
print_step(t("run_checking_model"))

user_model_obj = None
resolved_model_path = model_path

# Check if model is a path
if Path(model).exists():
resolved_model_path = Path(model)
print_info(t("run_model_path", path=str(resolved_model_path)))

# Try to find in user registry by path
user_model_obj = user_registry.find_by_path(str(resolved_model_path))
if user_model_obj:
print_info(f"Using registered model: {user_model_obj.name}")
else:
print_warning("Using unregistered model path. Consider adding it with 'kt model add'")
else:
# Search in user registry by name
user_model_obj = user_registry.get_model(model)

if not user_model_obj:
print_error(t("run_model_not_found", name=model))
console.print()
console.print(
f" Download with: kt download {model_info.aliases[0] if model_info.aliases else model_info.name}"
)

# Show available models
all_models = user_registry.list_models()
if all_models:
console.print("Available registered models:")
for m in all_models[:5]:
console.print(f" - {m.name}")
if len(all_models) > 5:
console.print(f" ... and {len(all_models) - 5} more")
else:
console.print("No models registered yet.")

console.print()
console.print(f"Add your model with: [cyan]kt model add /path/to/model[/cyan]")
console.print(f"Or scan for models: [cyan]kt model scan[/cyan]")
raise typer.Exit(1)

print_info(t("run_model_path", path=str(resolved_model_path)))
# Use model path from registry
resolved_model_path = Path(user_model_obj.path)

# Step 3: Check quantized weights (only if explicitly requested)
resolved_weights_path = None
# Verify path exists
if not resolved_model_path.exists():
print_error(f"Model path does not exist: {resolved_model_path}")
console.print()
console.print(f"Run 'kt model refresh' to check all models")
raise typer.Exit(1)

# Only use quantized weights if explicitly specified by user
if weights_path is not None:
# User explicitly specified weights path
resolved_weights_path = weights_path
if not resolved_weights_path.exists():
print_error(t("run_weights_not_found"))
console.print(f" Path: {resolved_weights_path}")
print_info(t("run_model_path", path=str(resolved_model_path)))

# Step 2.5: Pre-run verification (optional integrity check)
if user_model_obj and user_model_obj.format == "safetensors":
from kt_kernel.cli.utils.model_verifier import pre_operation_verification

pre_operation_verification(user_model_obj, user_registry, operation_name="running")

# Step 3: Check quantized weights (only if explicitly requested)
resolved_weights_path = None

# Only use quantized weights if explicitly specified by user
if weights_path is not None:
# User explicitly specified weights path
resolved_weights_path = weights_path
if not resolved_weights_path.exists():
print_error(t("run_weights_not_found"))
console.print(f" Path: {resolved_weights_path}")
raise typer.Exit(1)
print_info(f"Using quantized weights: {resolved_weights_path}")
elif quantize:
# User requested quantization
console.print()
print_step(t("run_quantizing"))
# TODO: Implement quantization
print_warning("Quantization not yet implemented. Please run 'kt quant' manually.")
raise typer.Exit(1)
print_info(f"Using quantized weights: {resolved_weights_path}")
elif quantize:
# User requested quantization
console.print()
print_step(t("run_quantizing"))
# TODO: Implement quantization
print_warning("Quantization not yet implemented. Please run 'kt quant' manually.")
raise typer.Exit(1)
else:
# Default: use original precision model without quantization
console.print()
print_info("Using original precision model (no quantization)")
else:
# Default: use original precision model without quantization
console.print()
print_info("Using original precision model (no quantization)")

# Step 4: Build command
# Resolve all parameters (CLI > model defaults > config > auto-detect)
final_host = host or settings.get("server.host", "0.0.0.0")
final_port = port or settings.get("server.port", 30000)
# Helper to resolve parameter with fallback chain: CLI > config > default
def resolve(cli_val, config_key, default):
if cli_val is not None:
return cli_val
config_val = settings.get(config_key)
return config_val if config_val is not None else default
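
A usage note on resolve() above (values hypothetical): resolve(8080, "server.port", 30000) returns 8080 because the CLI value wins; resolve(None, "server.port", 30000) falls through to settings.get("server.port") and only then to 30000. Because the test is `cli_val is not None` rather than truthiness, an explicit 0 or False from the CLI still takes priority, which a plain `cli_val or config_val or default` chain would silently drop.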

# Get defaults from model info if available
model_defaults = model_info.default_params if model_info else {}
# Server configuration
final_host = resolve(host, "server.host", "0.0.0.0")
final_port = resolve(port, "server.port", 30000)

# Determine tensor parallel size first (needed for GPU expert calculation)
# Priority: CLI > model defaults > config > auto-detect (with model constraints)

# Check if explicitly specified by user or configuration
explicitly_specified = (
tensor_parallel_size # CLI argument (highest priority)
or model_defaults.get("tensor-parallel-size") # Model defaults
or settings.get("inference.tensor_parallel_size") # Config file
# Tensor parallel size: CLI > config > auto-detect from GPUs
final_tensor_parallel_size = resolve(
tensor_parallel_size, "inference.tensor_parallel_size", len(gpus) if gpus else 1
)

if explicitly_specified:
# Use explicitly specified value
requested_tensor_parallel_size = explicitly_specified
else:
# Auto-detect from GPUs, considering model's max constraint
detected_gpu_count = len(gpus) if gpus else 1
if model_info and model_info.max_tensor_parallel_size is not None:
# Automatically limit to model's maximum to use as many GPUs as possible
requested_tensor_parallel_size = min(detected_gpu_count, model_info.max_tensor_parallel_size)
else:
requested_tensor_parallel_size = detected_gpu_count

# Apply model's max_tensor_parallel_size constraint if explicitly specified value exceeds it
final_tensor_parallel_size = requested_tensor_parallel_size
if model_info and model_info.max_tensor_parallel_size is not None:
if requested_tensor_parallel_size > model_info.max_tensor_parallel_size:
console.print()
print_warning(
f"Model {model_info.name} only supports up to {model_info.max_tensor_parallel_size}-way "
f"tensor parallelism, but {requested_tensor_parallel_size} was requested. "
f"Reducing to {model_info.max_tensor_parallel_size}."
)
final_tensor_parallel_size = model_info.max_tensor_parallel_size

# CPU/GPU configuration with smart defaults
# kt-cpuinfer: default to 80% of total CPU threads (cores * NUMA nodes)
total_threads = cpu.cores * cpu.numa_nodes
final_cpu_threads = (
cpu_threads
or model_defaults.get("kt-cpuinfer")
or settings.get("inference.cpu_threads")
or int(total_threads * 0.8)
)

# kt-threadpool-count: default to NUMA node count
final_numa_nodes = (
numa_nodes
or model_defaults.get("kt-threadpool-count")
or settings.get("inference.numa_nodes")
or cpu.numa_nodes
)

# kt-num-gpu-experts: use model-specific computation if available and not explicitly set
if gpu_experts is not None:
# User explicitly set it
final_gpu_experts = gpu_experts
elif model_info and model_info.name in MODEL_COMPUTE_FUNCTIONS and gpus:
# Use model-specific computation function (only if GPUs detected)
vram_per_gpu = gpus[0].vram_gb
compute_func = MODEL_COMPUTE_FUNCTIONS[model_info.name]
final_gpu_experts = compute_func(final_tensor_parallel_size, vram_per_gpu)
console.print()
print_info(
f"Auto-computed kt-num-gpu-experts: {final_gpu_experts} (TP={final_tensor_parallel_size}, VRAM={vram_per_gpu}GB per GPU)"
)
else:
# Fall back to defaults
final_gpu_experts = model_defaults.get("kt-num-gpu-experts") or settings.get("inference.gpu_experts", 1)
total_threads = cpu.threads # Use logical threads instead of physical cores
final_cpu_threads = resolve(cpu_threads, "inference.cpu_threads", int(total_threads * 0.8))
final_numa_nodes = resolve(numa_nodes, "inference.numa_nodes", cpu.numa_nodes)
final_gpu_experts = resolve(gpu_experts, "inference.gpu_experts", 1)
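
For scale (hypothetical hardware): a dual-socket box with 48 physical cores per socket and SMT enabled reports cpu.threads = 192 logical threads, so the kt-cpuinfer default above becomes int(192 * 0.8) = 153, leaving headroom for the scheduler, tokenizer, and OS. An explicit cpu_threads value overrides this through the same resolve() chain.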
|
||||
|
||||
# KT-kernel options
|
||||
final_kt_method = kt_method or model_defaults.get("kt-method") or settings.get("inference.kt_method", "AMXINT4")
|
||||
final_kt_gpu_prefill_threshold = (
|
||||
kt_gpu_prefill_threshold
|
||||
or model_defaults.get("kt-gpu-prefill-token-threshold")
|
||||
or settings.get("inference.kt_gpu_prefill_token_threshold", 4096)
|
||||
)
|
||||
final_kt_method = resolve(kt_method, "inference.kt_method", "AMXINT4")
|
||||
final_kt_gpu_prefill_threshold = resolve(kt_gpu_prefill_threshold, "inference.kt_gpu_prefill_token_threshold", 4096)
|
||||
|
||||
# SGLang options
|
||||
final_attention_backend = (
|
||||
attention_backend
|
||||
or model_defaults.get("attention-backend")
|
||||
or settings.get("inference.attention_backend", "triton")
|
||||
)
|
||||
final_max_total_tokens = (
|
||||
max_total_tokens or model_defaults.get("max-total-tokens") or settings.get("inference.max_total_tokens", 40000)
|
||||
)
|
||||
final_max_running_requests = (
|
||||
max_running_requests
|
||||
or model_defaults.get("max-running-requests")
|
||||
or settings.get("inference.max_running_requests", 32)
|
||||
)
|
||||
final_chunked_prefill_size = (
|
||||
chunked_prefill_size
|
||||
or model_defaults.get("chunked-prefill-size")
|
||||
or settings.get("inference.chunked_prefill_size", 4096)
|
||||
)
|
||||
final_mem_fraction_static = (
|
||||
mem_fraction_static
|
||||
or model_defaults.get("mem-fraction-static")
|
||||
or settings.get("inference.mem_fraction_static", 0.98)
|
||||
)
|
||||
final_watchdog_timeout = (
|
||||
watchdog_timeout or model_defaults.get("watchdog-timeout") or settings.get("inference.watchdog_timeout", 3000)
|
||||
)
|
||||
final_served_model_name = (
|
||||
served_model_name or model_defaults.get("served-model-name") or settings.get("inference.served_model_name", "")
|
||||
)
|
||||
final_attention_backend = resolve(attention_backend, "inference.attention_backend", "flashinfer")
|
||||
final_max_total_tokens = resolve(max_total_tokens, "inference.max_total_tokens", 40000)
|
||||
final_max_running_requests = resolve(max_running_requests, "inference.max_running_requests", 32)
|
||||
final_chunked_prefill_size = resolve(chunked_prefill_size, "inference.chunked_prefill_size", 4096)
|
||||
final_mem_fraction_static = resolve(mem_fraction_static, "inference.mem_fraction_static", 0.98)
|
||||
final_watchdog_timeout = resolve(watchdog_timeout, "inference.watchdog_timeout", 3000)
|
||||
final_served_model_name = resolve(served_model_name, "inference.served_model_name", "")
|
||||
|
||||
# Performance flags
|
||||
if disable_shared_experts_fusion is not None:
|
||||
final_disable_shared_experts_fusion = disable_shared_experts_fusion
|
||||
elif "disable-shared-experts-fusion" in model_defaults:
|
||||
final_disable_shared_experts_fusion = model_defaults["disable-shared-experts-fusion"]
|
||||
else:
|
||||
final_disable_shared_experts_fusion = settings.get("inference.disable_shared_experts_fusion", False)
|
||||
final_disable_shared_experts_fusion = resolve(
|
||||
disable_shared_experts_fusion, "inference.disable_shared_experts_fusion", True
|
||||
)
|
||||
|
||||
# Pass all model default params to handle any extra parameters
|
||||
extra_params = model_defaults if model_info else {}
|
||||
# Pass extra CLI parameters
|
||||
extra_params = {}
|
||||
|
||||
# Parser parameters (from interactive mode or None in non-interactive mode)
|
||||
final_tool_call_parser = None
|
||||
final_reasoning_parser = None
|
||||
if "tool_call_parser" in locals() and tool_call_parser:
|
||||
final_tool_call_parser = tool_call_parser
|
||||
if "reasoning_parser" in locals() and reasoning_parser:
|
||||
final_reasoning_parser = reasoning_parser
|
||||
|
||||
cmd = _build_sglang_command(
model_path=resolved_model_path,
weights_path=resolved_weights_path,
model_info=model_info,
host=final_host,
port=final_port,
gpu_experts=final_gpu_experts,
@@ -490,6 +472,8 @@ def _run_impl(
watchdog_timeout=final_watchdog_timeout,
served_model_name=final_served_model_name,
disable_shared_experts_fusion=final_disable_shared_experts_fusion,
tool_call_parser=final_tool_call_parser,
reasoning_parser=final_reasoning_parser,
settings=settings,
extra_model_params=extra_params,
extra_cli_args=extra_cli_args,
@@ -508,11 +492,9 @@ def _run_impl(
console.print()
print_step("Configuration")

# Model info
if model_info:
console.print(f" Model: [bold]{model_info.name}[/bold]")
else:
console.print(f" Model: [bold]{resolved_model_path.name}[/bold]")
# Display model name
model_display_name = user_model_obj.name if user_model_obj else resolved_model_path.name
console.print(f" Model: [bold]{model_display_name}[/bold]")

console.print(f" Path: [dim]{resolved_model_path}[/dim]")

@@ -572,88 +554,13 @@ def _run_impl(
raise typer.Exit(1)


def _find_model_path(model_info: ModelInfo, settings, max_depth: int = 3) -> Optional[Path]:
"""Find the model path on disk by searching all configured model paths.

Args:
model_info: Model information to search for
settings: Settings instance
max_depth: Maximum depth to search within each model path (default: 3)

Returns:
Path to the model directory, or None if not found
"""
model_paths = settings.get_model_paths()

# Generate possible names to search for
possible_names = [
model_info.name,
model_info.name.lower(),
model_info.name.replace(" ", "-"),
model_info.hf_repo.split("/")[-1],
model_info.hf_repo.replace("/", "--"),
]

# Add alias-based names
for alias in model_info.aliases:
possible_names.append(alias)
possible_names.append(alias.lower())

# Search in all configured model directories
for models_dir in model_paths:
if not models_dir.exists():
continue

# Search recursively up to max_depth
for depth in range(max_depth):
for name in possible_names:
if depth == 0:
# Direct children: models_dir / name
search_paths = [models_dir / name]
else:
# Nested: use rglob to find directories matching the name
search_paths = list(models_dir.rglob(name))

for path in search_paths:
if path.exists() and (path / "config.json").exists():
return path

return None


def _find_weights_path(model_info: ModelInfo, settings) -> Optional[Path]:
"""Find the quantized weights path on disk by searching all configured paths."""
model_paths = settings.get_model_paths()
weights_dir = settings.weights_dir

# Check common patterns
base_names = [
model_info.name,
model_info.name.lower(),
model_info.hf_repo.split("/")[-1],
]

suffixes = ["-INT4", "-int4", "_INT4", "_int4", "-quant", "-quantized"]

# Prepare search directories
search_dirs = [weights_dir] if weights_dir else []
search_dirs.extend(model_paths)

for base in base_names:
for suffix in suffixes:
for dir_path in search_dirs:
if dir_path:
path = dir_path / f"{base}{suffix}"
if path.exists():
return path

return None
# Dead code removed: _find_model_path() and _find_weights_path()
# These functions were part of the old builtin model system


def _build_sglang_command(
model_path: Path,
weights_path: Optional[Path],
model_info: Optional[ModelInfo],
host: str,
port: int,
gpu_experts: int,
@@ -670,6 +577,8 @@ def _build_sglang_command(
watchdog_timeout: int,
served_model_name: str,
disable_shared_experts_fusion: bool,
tool_call_parser: Optional[str],
reasoning_parser: Optional[str],
settings,
extra_model_params: Optional[dict] = None,  # New parameter for additional params
extra_cli_args: Optional[list[str]] = None,  # Extra args from CLI to pass to sglang
@@ -700,9 +609,6 @@ def _build_sglang_command(
elif cpu_threads > 0 or gpu_experts > 1:
# CPU offloading configured - use kt-kernel
use_kt_kernel = True
elif model_info and model_info.type == "moe":
# MoE model - likely needs kt-kernel for expert offloading
use_kt_kernel = True

if use_kt_kernel:
# Add kt-weight-path: use quantized weights if available, otherwise use model path
@@ -723,6 +629,7 @@ def _build_sglang_command(
kt_method,
"--kt-gpu-prefill-token-threshold",
str(kt_gpu_prefill_threshold),
"--kt-enable-dynamic-expert-update",  # Enable dynamic expert updates
]
)

@@ -757,6 +664,16 @@ def _build_sglang_command(
if disable_shared_experts_fusion:
cmd.append("--disable-shared-experts-fusion")

# Add FP8 backend if using FP8 method
if "FP8" in kt_method.upper():
cmd.extend(["--fp8-gemm-backend", "triton"])

# Add parsers if specified
if tool_call_parser:
cmd.extend(["--tool-call-parser", tool_call_parser])
if reasoning_parser:
cmd.extend(["--reasoning-parser", reasoning_parser])

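# Illustration only -- a kt-kernel run with parsers configured ends up extending
# cmd with flags like the following (values hypothetical; only flags visible in
# this hunk are shown, and the full launcher command lives outside it):
#
#     --kt-gpu-prefill-token-threshold 4096 --kt-enable-dynamic-expert-update
#     --disable-shared-experts-fusion --fp8-gemm-backend triton
#     --tool-call-parser qwen25 --reasoning-parser deepseek-r1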
# Add any extra parameters from model defaults that weren't explicitly handled
if extra_model_params:
# List of parameters already handled above
@@ -801,30 +718,31 @@ def _build_sglang_command(
return cmd


def _interactive_model_selection(registry, settings) -> Optional[str]:
def _interactive_model_selection(user_registry, settings) -> Optional[str]:
"""Show interactive model selection interface.

Returns:
Selected model name or None if cancelled.
"""
from rich.panel import Panel
from rich.table import Table
from rich.prompt import Prompt

from kt_kernel.cli.i18n import get_lang
# Get all user models
all_models = user_registry.list_models()

lang = get_lang()

# Find local models first
local_models = registry.find_local_models()

# Get all registered models
all_models = registry.list_all()
if not all_models:
console.print()
print_warning("No models registered.")
console.print()
console.print(f" Add models with: [cyan]kt model scan[/cyan]")
console.print(f" Or manually: [cyan]kt model add /path/to/model[/cyan]")
console.print()
return None

console.print()
console.print(
Panel.fit(
t("run_select_model_title"),
"Select a model to run",
border_style="cyan",
)
)
@@ -834,54 +752,30 @@ def _interactive_model_selection(registry, settings) -> Optional[str]:
choices = []
choice_map = {}  # index -> model name

# Section 1: Local models (downloaded)
if local_models:
console.print(f"[bold green]{t('run_local_models')}[/bold green]")
console.print()

for i, (model_info, path) in enumerate(local_models, 1):
desc = model_info.description_zh if lang == "zh" else model_info.description
short_desc = desc[:50] + "..." if len(desc) > 50 else desc
console.print(f" [cyan][{i}][/cyan] [bold]{model_info.name}[/bold]")
console.print(f" [dim]{short_desc}[/dim]")
console.print(f" [dim]{path}[/dim]")
choices.append(str(i))
choice_map[str(i)] = model_info.name

console.print()

# Section 2: All registered models (for reference)
start_idx = len(local_models) + 1
console.print(f"[bold yellow]{t('run_registered_models')}[/bold yellow]")
# Show all user models
console.print(f"[bold green]Available Models:[/bold green]")
console.print()

# Filter out already shown local models
local_model_names = {m.name for m, _ in local_models}

for i, model_info in enumerate(all_models, start_idx):
if model_info.name in local_model_names:
continue

desc = model_info.description_zh if lang == "zh" else model_info.description
short_desc = desc[:50] + "..." if len(desc) > 50 else desc
console.print(f" [cyan][{i}][/cyan] [bold]{model_info.name}[/bold]")
console.print(f" [dim]{short_desc}[/dim]")
console.print(f" [dim]{model_info.hf_repo}[/dim]")
for i, model in enumerate(all_models, 1):
# Check if path exists
path_status = "✓" if model.path_exists() else "✗ Missing"
console.print(f" [cyan][{i}][/cyan] [bold]{model.name}[/bold] [{path_status}]")
console.print(f" [dim]{model.format} - {model.path}[/dim]")
choices.append(str(i))
choice_map[str(i)] = model_info.name
choice_map[str(i)] = model.name

console.print()

# Add cancel option
cancel_idx = str(len(choices) + 1)
console.print(f" [cyan][{cancel_idx}][/cyan] [dim]{t('cancel')}[/dim]")
console.print(f" [cyan][{cancel_idx}][/cyan] [dim]Cancel[/dim]")
choices.append(cancel_idx)
console.print()

# Prompt for selection
try:
selection = Prompt.ask(
t("run_select_model_prompt"),
"Select model",
choices=choices,
default="1" if choices else cancel_idx,
)

@@ -9,7 +9,7 @@ _kt_completion() {
prev="${COMP_WORDS[COMP_CWORD-1]}"

# Main commands
local commands="version run chat quant bench microbench doctor model config sft"
local commands="version run chat quant edit bench microbench doctor model config sft"

# Global options
local global_opts="--help --version"
@@ -36,6 +36,10 @@ _kt_completion() {
local quant_opts="--method --output --help"
COMPREPLY=( $(compgen -W "${quant_opts}" -- ${cur}) )
;;
edit)
local edit_opts="--help"
COMPREPLY=( $(compgen -W "${edit_opts}" -- ${cur}) )
;;
bench|microbench)
local bench_opts="--model --config --help"
COMPREPLY=( $(compgen -W "${bench_opts}" -- ${cur}) )

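# Once this script is sourced (e.g. from ~/.bashrc), the new subcommand takes
# part in tab completion: "kt ed<TAB>" completes to "kt edit", and
# "kt edit --<TAB>" offers the "--help" option registered above.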
@@ -190,6 +190,70 @@ MESSAGES: dict[str, dict[str, str]] = {
"quant_progress": "Quantizing...",
"quant_complete": "Quantization complete!",
"quant_input_not_found": "Input model not found at {path}",
"quant_cpu_threads": "CPU threads: {threads}",
"quant_numa_nodes": "NUMA nodes: {nodes}",
"quant_time_warning": "Quantization may take 30-60 minutes depending on model size.",
"quant_disk_analysis": "Disk Space Analysis:",
"quant_source_size": "Source model size:",
"quant_estimated_size": "Estimated output size:",
"quant_available_space": "Available space:",
"quant_insufficient_space": "WARNING: Insufficient disk space!",
"quant_required_space": "Required space (with 20% buffer):",
"quant_shortage": "Shortage:",
"quant_may_fail": "Quantization may fail or produce incomplete files.",
"quant_continue_anyway": "Continue anyway?",
"quant_settings": "Quantization Settings:",
"quant_registered": "Quantized model registered: {name}",
"quant_view_with": "View with:",
"quant_use_with": "Use with:",
"quant_register_failed": "Failed to auto-register model: {error}",
"quant_output_exists": "Output path already exists: {path}",
"quant_using_unique": "Using unique name: {path}",
# Interactive quant
"quant_interactive_title": "Interactive Quantization Configuration",
"quant_new_model_notice": "⚠ Note: Some newer models cannot be quantized yet (conversion script not adapted). Recommended to use the original precision for inference (no weight conversion needed).",
"quant_no_moe_models": "No MoE models found for quantization.",
"quant_only_moe": "Only MoE models (e.g., DeepSeek-V3) can be quantized to AMX format.",
"quant_add_models": "Add models with: {command}",
"quant_moe_available": "MoE Models Available for Quantization:",
"quant_select_model": "Select model to quantize",
"quant_invalid_choice": "Invalid choice",
"quant_step2_method": "Step 2: Quantization Method",
"quant_method_label": "Quantization Method:",
"quant_int4_desc": "INT4",
"quant_int8_desc": "INT8",
"quant_select_method": "Select quantization method",
"quant_input_type_label": "Input Weight Type:",
"quant_fp8_desc": "FP8 (for 8-bit float weights)",
"quant_fp16_desc": "FP16 (for 16-bit float weights)",
"quant_bf16_desc": "BF16 (for Brain Float 16 weights)",
"quant_select_input_type": "Select input type",
"quant_step3_cpu": "Step 3: CPU Configuration",
"quant_cpu_threads_prompt": "CPU Threads (1 to {max})",
"quant_numa_nodes_prompt": "NUMA Nodes (1 to {max})",
"quant_use_gpu_label": "Use GPU for conversion?",
"quant_gpu_speedup": "GPU can significantly speed up the quantization process",
"quant_enable_gpu": "Enable GPU acceleration?",
"quant_step4_output": "Step 4: Output Path",
"quant_default_path": "Default:",
"quant_use_default": "Use default output path?",
"quant_custom_path": "Enter custom output path",
"quant_output_exists_warn": "⚠ Output path already exists: {path}",
"quant_using_unique_name": "→ Using unique name: {path}",
"quant_config_summary": "Configuration Summary",
"quant_summary_model": "Model:",
"quant_summary_method": "Method:",
"quant_summary_input_type": "Input Type:",
"quant_summary_cpu_threads": "CPU Threads:",
"quant_summary_numa": "NUMA Nodes:",
"quant_summary_gpu": "Use GPU:",
"quant_summary_output": "Output Path:",
"quant_start_question": "Start quantization?",
"quant_cancelled": "Cancelled",
"quant_config_complete": "Configuration complete",
"quant_time_elapsed": "Time elapsed:",
"yes": "Yes",
"no": "No",
# SFT command
"sft_mode_train": "Training mode",
"sft_mode_chat": "Chat mode",
@@ -247,6 +311,113 @@ MESSAGES: dict[str, dict[str, str]] = {
"chat_proxy_detected": "Proxy detected in environment",
"chat_proxy_confirm": "Use proxy for connection?",
"chat_proxy_disabled": "Proxy disabled for this session",
"chat_openai_required": "OpenAI Python SDK is required for chat functionality.",
"chat_install_hint": "Install it with:",
"chat_title": "KTransformers Chat",
"chat_server": "Server",
"chat_temperature": "Temperature",
"chat_max_tokens": "Max tokens",
"chat_help_hint": "Type '/help' for commands, '/quit' to exit",
"chat_connecting": "Connecting to server...",
"chat_no_models": "No models available on server",
"chat_model_not_found": "Model '{model}' not found. Available models: {available}",
"chat_connected": "Connected to model: {model}",
"chat_connect_failed": "Failed to connect to server: {error}",
"chat_server_not_running": "Make sure the model server is running:",
"chat_user_prompt": "You",
"chat_assistant_prompt": "Assistant",
"chat_generation_error": "Error generating response: {error}",
"chat_interrupted": "Chat interrupted. Goodbye!",
"chat_history_saved": "History saved to: {path}",
"chat_goodbye": "Goodbye!",
"chat_help_title": "Available Commands:",
"chat_help_content": "/help, /h - Show this help message\n/quit, /exit, /q - Exit chat\n/clear, /c - Clear conversation history\n/history, /hist - Show conversation history\n/info, /i - Show current settings\n/retry, /r - Regenerate last response",
"chat_history_cleared": "Conversation history cleared",
"chat_no_history": "No conversation history",
"chat_history_title": "History ({count} messages)",
"chat_info_title": "Current Settings:",
"chat_info_content": "Temperature: {temperature}\nMax tokens: {max_tokens}\nMessages: {messages}",
"chat_retrying": "Retrying last response...",
"chat_no_retry": "No previous response to retry",
"chat_unknown_command": "Unknown command: {command}",
"chat_unknown_hint": "Type /help for available commands",
# Run Interactive
"run_int_no_moe_models": "No MoE GPU models found.",
"run_int_add_models": "Add models with: kt model scan",
"run_int_list_all": "List all models: kt model list --all",
"run_int_step1_title": "Step 1: Select Model (GPU MoE Models)",
"run_int_select_model": "Select model",
"run_int_step2_title": "Step 2: Select Inference Method",
"run_int_method_raw": "RAW Precision (FP8/FP8_PERCHANNEL/BF16/RAWINT4)",
"run_int_method_amx": "AMX Quantization (INT4/INT8)",
"run_int_method_gguf": "GGUF (Llamafile)",
"run_int_method_saved": "Use Saved Configuration",
"run_int_select_method": "Select inference method",
"run_int_raw_precision": "RAW Precision:",
"run_int_select_precision": "Select precision",
"run_int_amx_method": "AMX Method:",
"run_int_select_amx": "Select AMX method",
"run_int_step3_title": "Step 3: NUMA and CPU Configuration",
"run_int_numa_nodes": "NUMA Nodes (1-{max})",
"run_int_cpu_threads": "CPU Threads per NUMA (1-{max})",
"run_int_amx_warning": "⚠ Warning: AMX INT4/INT8 requires compatible CPU. Check with: kt doctor",
"run_int_step4_title": "Step 4: GPU Experts Configuration",
"run_int_gpu_experts": "GPU Experts per Layer (0-{max})",
"run_int_gpu_experts_info": "Total experts: {total}, Activated per token: {active}",
"run_int_step5_title": "Step 5: KV Cache Configuration",
"run_int_kv_cache_size": "KV Cache Size (tokens)",
"run_int_chunk_prefill": "Enable Chunk Prefill?",
"run_int_chunk_size": "Chunk Prefill Size (tokens)",
"run_int_gpu_prefill_threshold": "GPU Prefill Threshold (tokens)",
"run_int_step6_title": "Step 6: GPU Selection and Tensor Parallelism",
"run_int_available_gpus": "Available GPUs:",
"run_int_gpu_id": "GPU {id}",
"run_int_vram_info": "{name} ({total:.1f}GB total, {free:.1f}GB free)",
"run_int_select_gpus": "Select GPU IDs (comma-separated)",
"run_int_invalid_gpu_range": "All GPU IDs must be between 0 and {max}",
"run_int_tp_size": "TP Size (must be power of 2: 1,2,4,8...)",
"run_int_tp_mismatch": "TP size must match number of selected GPUs ({count})",
"run_int_tp_not_power_of_2": "TP size must be a power of 2",
"run_int_mem_fraction": "Static Memory Fraction (0.0-1.0)",
"run_int_using_saved_mem": "Using saved memory fraction: {fraction}",
"run_int_step7_title": "Step 7: Parser Configuration (Optional)",
"run_int_tool_call_parser": "Tool Call Parser (press Enter to skip)",
"run_int_reasoning_parser": "Reasoning Parser (press Enter to skip)",
"run_int_step8_title": "Step 8: Host and Port Configuration",
"run_int_host": "Host",
"run_int_port": "Port",
"run_int_port_occupied": "⚠ Port {port} is already in use",
"run_int_port_suggestion": "Suggested available port: {port}",
"run_int_use_suggested": "Use suggested port?",
"run_int_saved_configs": "Saved Configurations:",
"run_int_config_name": "Configuration {num}",
"run_int_kt_method": "KT Method:",
"run_int_numa_nodes_label": "NUMA Nodes:",
"run_int_cpu_threads_label": "CPU Threads:",
"run_int_gpu_experts_label": "GPU Experts:",
"run_int_tp_size_label": "TP Size:",
"run_int_mem_fraction_label": "Memory Fraction:",
"run_int_server_label": "Server:",
"run_int_kv_cache_label": "KV Cache:",
"run_int_chunk_prefill_label": "Chunk Prefill:",
"run_int_gpu_prefill_label": "GPU Prefill Thr:",
"run_int_tool_parser_label": "Tool Call Parser:",
"run_int_reasoning_parser_label": "Reasoning Parser:",
"run_int_command_label": "Command:",
"run_int_select_config": "Select configuration",
"run_int_gpu_select_required": "Please select {tp} GPUs (TP size from saved config)",
"run_int_port_check_title": "Port Configuration",
"run_int_port_checking": "Checking port {port} availability...",
"run_int_port_available": "Port {port} is available",
"run_int_saved_config_title": "Saved Configuration",
"run_int_save_config_title": "Save Configuration",
"run_int_save_config_prompt": "Save this configuration for future use?",
"run_int_config_name_prompt": "Configuration name",
"run_int_config_name_default": "Config {timestamp}",
"run_int_config_saved": "Configuration saved: {name}",
"run_int_config_summary": "Configuration Complete",
"run_int_model_label": "Model:",
"run_int_selected_gpus_label": "Selected GPUs:",
# Model command
"model_supported_title": "KTransformers Supported Models",
"model_column_model": "Model",
@@ -282,6 +453,180 @@ MESSAGES: dict[str, dict[str, str]] = {
"model_column_name": "Name",
"model_column_hf_repo": "HuggingFace Repo",
"model_column_aliases": "Aliases",
# Model management - new user registry system
"model_no_registered_models": "No models registered yet.",
"model_scan_hint": "Scan for models: kt model scan",
"model_add_hint": "Add a model: kt model add /path/to/model",
"model_registered_models_title": "Registered Models",
"model_column_format": "Format",
"model_column_repo": "Repository",
"model_column_sha256": "SHA256",
"model_non_moe_hidden_hint": "Detected {count} non-MoE models, use kt model list --all to show all",
"model_usage_title": "Common Operations:",
"model_usage_info": "View details:",
"model_usage_edit": "Edit model:",
"model_usage_verify": "Verify integrity:",
"model_usage_quant": "Quantize model:",
"model_usage_run": "Run model:",
"model_usage_scan": "Scan for models:",
"model_usage_add": "Add model:",
"model_usage_verbose": "View with file details:",
"model_no_storage_paths": "No storage paths configured.",
"model_add_path_hint": "Add a storage path with: kt config set model.storage_paths /path/to/models",
"model_scanning_paths": "Scanning configured storage paths...",
"model_scanning_progress": "Scanning: {path}",
"model_scan_warnings_title": "Warnings",
"model_scan_no_models_found": "No models found in configured paths.",
"model_scan_check_paths_hint": "Check your storage paths: kt config get model.storage_paths",
"model_scan_min_size_hint": "Folders must be ≥{size}GB to be detected as models.",
"model_scan_found_title": "Found {count} new model(s)",
"model_column_path": "Path",
"model_column_size": "Size",
"model_scan_auto_adding": "Auto-adding models...",
"model_added": "Added: {name}",
"model_add_failed": "Failed to add {name}: {error}",
"model_scan_complete": "Scan complete! Added {count} model(s).",
"model_scan_interactive_prompt": "Commands: edit <id> | del <id> | done",
"model_scan_cmd_edit": "Set custom name for model",
"model_scan_cmd_delete": "Skip this model",
"model_scan_cmd_done": "Finish and add models",
"model_scan_marked_skip": "Skipped model #{id}",
"model_scan_invalid_id": "Invalid model ID: {id}",
"model_scan_invalid_command": "Invalid command. Use: edit <id> | del <id> | done",
"model_scan_edit_model": "Edit model {id}",
"model_scan_edit_note": "You can change the model name before adding it to registry",
"model_scan_adding_models": "Adding {count} model(s)...",
"model_scan_next_steps": "Next Steps",
"model_scan_view_hint": "View registered models: kt model list",
"model_scan_edit_hint": "Edit model details: kt model edit <name>",
"model_scan_no_models_added": "No models were added.",
"model_add_path_not_exist": "Error: Path does not exist: {path}",
"model_add_not_directory": "Error: Path is not a directory: {path}",
"model_add_already_registered": "This path is already registered as: {name}",
"model_add_view_hint": "View with: kt model info {name}",
"model_add_scanning": "Scanning model files...",
"model_add_scan_failed": "Failed to scan model: {error}",
"model_add_no_model_files": "No model files found in {path}",
"model_add_supported_formats": "Supported: *.safetensors, *.gguf (folder ≥10GB)",
"model_add_detected": "Detected: {format} format, {size}, {count} file(s)",
"model_add_name_conflict": "Name '{name}' already exists.",
"model_add_prompt_name": "Enter a name for this model",
"model_add_name_exists": "Name already exists. Please choose another name:",
"model_add_configure_repo": "Configure repository information for SHA256 verification?",
"model_add_repo_type_prompt": "Select repository type:",
"model_add_choice": "Choice",
"model_add_repo_id_prompt": "Enter repository ID (e.g., deepseek-ai/DeepSeek-V3)",
"model_add_success": "Successfully added model: {name}",
"model_add_verify_hint": "Verify integrity: kt model verify {name}",
"model_add_edit_later_hint": "Edit details later: kt model edit {name}",
"model_add_failed_generic": "Failed to add model: {error}",
"model_edit_not_found": "Model '{name}' not found.",
"model_edit_list_hint": "List models: kt model list",
"model_edit_current_config": "Current Configuration",
"model_edit_what_to_edit": "What would you like to edit?",
"model_edit_option_name": "Edit name",
"model_edit_option_repo": "Configure repository info",
"model_edit_option_delete": "Delete this model",
"model_edit_option_cancel": "Cancel / Exit",
"model_edit_choice_prompt": "Select option",
"model_edit_new_name": "Enter new name",
"model_edit_name_conflict": "Name '{name}' already exists. Please choose another:",
"model_edit_name_updated": "Name updated: {old} → {new}",
"model_edit_repo_type_prompt": "Repository type (or enter to remove repo info):",
"model_edit_repo_remove": "Remove repository info",
"model_edit_repo_id_prompt": "Enter repository ID",
"model_edit_repo_removed": "Repository info removed",
"model_edit_repo_updated": "Repository configured: {repo_type} → {repo_id}",
"model_edit_delete_warning": "Delete model '{name}' from registry?",
"model_edit_delete_note": "Note: This only removes the registry entry. Model files in {path} will NOT be deleted.",
"model_edit_delete_confirm": "Confirm deletion?",
"model_edit_deleted": "Model '{name}' deleted from registry",
"model_edit_delete_cancelled": "Deletion cancelled",
"model_edit_cancelled": "Edit cancelled",
# Model edit - Interactive selection
"model_edit_select_title": "Select Model to Edit",
"model_edit_select_model": "Select model",
"model_edit_invalid_choice": "Invalid choice",
"model_edit_no_models": "No models found in registry.",
"model_edit_add_hint_scan": "Add models with:",
"model_edit_add_hint_add": "Or:",
# Model edit - Display
"model_edit_gpu_links": "GPU Links:",
# Model edit - Menu options
"model_edit_manage_gpu_links": "Manage GPU Links",
"model_edit_save_changes": "Save changes",
"model_edit_has_changes": "(has changes)",
"model_edit_no_changes": "(no changes)",
# Model edit - Pending changes messages
"model_edit_name_pending": "Name will be updated when you save changes.",
"model_edit_repo_remove_pending": "Repository info will be removed when you save changes.",
"model_edit_repo_update_pending": "Repository info will be updated when you save changes.",
# Model edit - GPU link management
"model_edit_gpu_links_title": "Manage GPU Links for {name}",
"model_edit_current_gpu_links": "Current GPU links:",
"model_edit_no_gpu_links": "No GPU links configured.",
"model_edit_gpu_options": "Options:",
"model_edit_gpu_add": "Add GPU link",
"model_edit_gpu_remove": "Remove GPU link",
"model_edit_gpu_clear": "Clear all GPU links",
"model_edit_gpu_back": "Back to main menu",
"model_edit_gpu_choose_option": "Choose option",
"model_edit_gpu_none_available": "No GPU models available to link.",
"model_edit_gpu_available_models": "Available GPU models:",
"model_edit_gpu_already_linked": "(already linked)",
"model_edit_gpu_enter_number": "Enter GPU model number to add",
"model_edit_gpu_link_pending": "GPU link will be added when you save changes: {name}",
"model_edit_gpu_already_exists": "This GPU model is already linked.",
"model_edit_gpu_invalid_choice": "Invalid choice.",
"model_edit_gpu_invalid_input": "Invalid input.",
"model_edit_gpu_none_to_remove": "No GPU links to remove.",
"model_edit_gpu_choose_to_remove": "Choose GPU link to remove:",
"model_edit_gpu_enter_to_remove": "Enter number to remove",
"model_edit_gpu_remove_pending": "GPU link will be removed when you save changes: {name}",
"model_edit_gpu_none_to_clear": "No GPU links to clear.",
"model_edit_gpu_clear_confirm": "Remove all GPU links?",
"model_edit_gpu_clear_pending": "All GPU links will be removed when you save changes.",
"model_edit_cancelled_short": "Cancelled.",
# Model edit - Save operation
"model_edit_no_changes_to_save": "No changes to save.",
"model_edit_saving": "Saving changes...",
"model_edit_saved": "Changes saved successfully!",
"model_edit_updated_config": "Updated Configuration:",
"model_edit_repo_changed_warning": "⚠ Repository information has changed.",
"model_edit_verify_hint": "Run [cyan]kt model verify[/cyan] to verify model integrity with SHA256 checksums.",
"model_edit_discard_changes": "Discard unsaved changes?",
"model_info_not_found": "Model '{name}' not found.",
"model_info_list_hint": "List all models: kt model list",
"model_remove_not_found": "Model '{name}' not found.",
"model_remove_list_hint": "List models: kt model list",
"model_remove_warning": "Remove model '{name}' from registry?",
"model_remove_note": "Note: This only removes the registry entry. Model files will NOT be deleted from {path}.",
"model_remove_confirm": "Confirm removal?",
"model_remove_cancelled": "Removal cancelled",
"model_removed": "Model '{name}' removed from registry",
"model_remove_failed": "Failed to remove model: {error}",
"model_refresh_checking": "Checking model paths...",
"model_refresh_all_valid": "All models are valid! ({count} model(s) checked)",
"model_refresh_total": "Total models: {total}",
"model_refresh_missing_found": "Found {count} missing model(s)",
"model_refresh_suggestions": "Suggested Actions",
"model_refresh_remove_hint": "Remove from registry: kt model remove <name>",
"model_refresh_rescan_hint": "Re-scan for models: kt model scan",
"model_verify_not_found": "Model '{name}' not found.",
"model_verify_list_hint": "List models: kt model list",
"model_verify_no_repo": "Model '{name}' has no repository information configured.",
"model_verify_config_hint": "Configure repository: kt model edit {name}",
"model_verify_path_missing": "Model path does not exist: {path}",
"model_verify_starting": "Verifying model integrity...",
"model_verify_progress": "Repository: {repo_type} → {repo_id}",
"model_verify_not_implemented": "SHA256 verification not implemented yet",
"model_verify_future_note": "This feature will fetch official SHA256 hashes from {repo_type} and compare with local files.",
"model_verify_passed": "Verification passed! All files match official hashes.",
"model_verify_failed": "Verification failed! {count} file(s) have hash mismatches.",
"model_verify_all_no_repos": "No models have repository information configured.",
"model_verify_all_config_hint": "Configure repos using: kt model edit <name>",
"model_verify_all_found": "Found {count} model(s) with repository info",
"model_verify_all_manual_hint": "Verify specific model: kt model verify <name>",
# Coming soon
"feature_coming_soon": "This feature is coming soon...",
},
@@ -465,6 +810,70 @@ MESSAGES: dict[str, dict[str, str]] = {
"quant_progress": "正在量化...",
"quant_complete": "量化完成!",
"quant_input_not_found": "未找到输入模型: {path}",
"quant_cpu_threads": "CPU 线程数: {threads}",
"quant_numa_nodes": "NUMA 节点数: {nodes}",
"quant_time_warning": "量化可能需要 30-60 分钟,具体取决于模型大小。",
"quant_disk_analysis": "磁盘空间分析:",
"quant_source_size": "源模型大小:",
"quant_estimated_size": "预估输出大小:",
"quant_available_space": "可用空间:",
"quant_insufficient_space": "警告:磁盘空间不足!",
"quant_required_space": "所需空间(含20%缓冲):",
"quant_shortage": "不足:",
"quant_may_fail": "量化可能失败或生成不完整的文件。",
"quant_continue_anyway": "仍然继续?",
"quant_settings": "量化设置:",
"quant_registered": "量化模型已注册:{name}",
"quant_view_with": "查看:",
"quant_use_with": "使用:",
"quant_register_failed": "自动注册模型失败:{error}",
"quant_output_exists": "输出路径已存在:{path}",
"quant_using_unique": "使用唯一名称:{path}",
# Interactive quant
"quant_interactive_title": "交互式量化配置",
"quant_new_model_notice": "⚠ 注意:部分新模型暂时无法量化(转换脚本未适配),推荐使用原精度进行推理(无需转换权重)。",
"quant_no_moe_models": "未找到可量化的 MoE 模型。",
"quant_only_moe": "只有 MoE 模型(如 DeepSeek-V3)可以被量化为 AMX 格式。",
"quant_add_models": "添加模型:{command}",
"quant_moe_available": "可量化的 MoE 模型:",
"quant_select_model": "选择要量化的模型",
"quant_invalid_choice": "无效选择",
"quant_step2_method": "第 2 步:量化方法",
"quant_method_label": "量化方法:",
"quant_int4_desc": "INT4",
"quant_int8_desc": "INT8",
"quant_select_method": "选择量化方法",
"quant_input_type_label": "输入权重类型:",
"quant_fp8_desc": "FP8(适用于 8 位浮点权重)",
"quant_fp16_desc": "FP16(适用于 16 位浮点权重)",
"quant_bf16_desc": "BF16(适用于 Brain Float 16 权重)",
"quant_select_input_type": "选择输入类型",
"quant_step3_cpu": "第 3 步:CPU 配置",
"quant_cpu_threads_prompt": "CPU 线程数(1 到 {max})",
"quant_numa_nodes_prompt": "NUMA 节点数(1 到 {max})",
"quant_use_gpu_label": "是否使用 GPU 进行转换?",
"quant_gpu_speedup": "GPU 可以显著加快量化速度",
"quant_enable_gpu": "启用 GPU 加速?",
"quant_step4_output": "第 4 步:输出路径",
"quant_default_path": "默认:",
"quant_use_default": "使用默认输出路径?",
"quant_custom_path": "输入自定义输出路径",
"quant_output_exists_warn": "⚠ 输出路径已存在:{path}",
"quant_using_unique_name": "→ 使用唯一名称:{path}",
"quant_config_summary": "配置摘要",
"quant_summary_model": "模型:",
"quant_summary_method": "方法:",
"quant_summary_input_type": "输入类型:",
"quant_summary_cpu_threads": "CPU 线程数:",
"quant_summary_numa": "NUMA 节点数:",
"quant_summary_gpu": "使用 GPU:",
"quant_summary_output": "输出路径:",
"quant_start_question": "开始量化?",
"quant_cancelled": "已取消",
"quant_config_complete": "配置完成",
"quant_time_elapsed": "耗时:",
"yes": "是",
"no": "否",
# SFT command
"sft_mode_train": "训练模式",
"sft_mode_chat": "聊天模式",
@@ -522,6 +931,113 @@ MESSAGES: dict[str, dict[str, str]] = {
"chat_proxy_detected": "检测到环境中存在代理设置",
"chat_proxy_confirm": "是否使用代理连接?",
"chat_proxy_disabled": "已在本次会话中禁用代理",
"chat_openai_required": "聊天功能需要 OpenAI Python SDK。",
"chat_install_hint": "安装命令:",
"chat_title": "KTransformers 对话",
"chat_server": "服务器",
"chat_temperature": "温度",
"chat_max_tokens": "最大 tokens",
"chat_help_hint": "输入 '/help' 查看命令,'/quit' 退出",
"chat_connecting": "正在连接服务器...",
"chat_no_models": "服务器上没有可用模型",
"chat_model_not_found": "未找到模型 '{model}'。可用模型:{available}",
"chat_connected": "已连接到模型:{model}",
"chat_connect_failed": "连接服务器失败:{error}",
"chat_server_not_running": "请确保模型服务器正在运行:",
"chat_user_prompt": "用户",
"chat_assistant_prompt": "助手",
"chat_generation_error": "生成回复时出错:{error}",
"chat_interrupted": "对话已中断。再见!",
"chat_history_saved": "历史记录已保存到:{path}",
"chat_goodbye": "再见!",
"chat_help_title": "可用命令:",
"chat_help_content": "/help, /h - 显示此帮助信息\n/quit, /exit, /q - 退出聊天\n/clear, /c - 清除对话历史\n/history, /hist - 显示对话历史\n/info, /i - 显示当前设置\n/retry, /r - 重新生成上一个回复",
"chat_history_cleared": "对话历史已清除",
"chat_no_history": "暂无对话历史",
"chat_history_title": "历史记录({count} 条消息)",
"chat_info_title": "当前设置:",
"chat_info_content": "温度:{temperature}\n最大 tokens:{max_tokens}\n消息数:{messages}",
"chat_retrying": "正在重试上一个回复...",
"chat_no_retry": "没有可重试的回复",
"chat_unknown_command": "未知命令:{command}",
"chat_unknown_hint": "输入 /help 查看可用命令",
# Run Interactive
"run_int_no_moe_models": "未找到 MoE GPU 模型。",
"run_int_add_models": "添加模型:kt model scan",
"run_int_list_all": "列出所有模型:kt model list --all",
"run_int_step1_title": "第 1 步:选择模型(GPU MoE 模型)",
"run_int_select_model": "选择模型",
"run_int_step2_title": "第 2 步:选择推理方法",
"run_int_method_raw": "RAW 精度(FP8/FP8_PERCHANNEL/BF16/RAWINT4)",
"run_int_method_amx": "AMX 量化(INT4/INT8)",
"run_int_method_gguf": "GGUF(Llamafile)",
"run_int_method_saved": "使用已保存的配置",
"run_int_select_method": "选择推理方法",
"run_int_raw_precision": "RAW 精度:",
"run_int_select_precision": "选择精度",
"run_int_amx_method": "AMX 方法:",
"run_int_select_amx": "选择 AMX 方法",
"run_int_step3_title": "第 3 步:NUMA 和 CPU 配置",
"run_int_numa_nodes": "NUMA 节点数(1-{max})",
"run_int_cpu_threads": "每个 NUMA 的 CPU 线程数(1-{max})",
"run_int_amx_warning": "⚠ 警告:AMX INT4/INT8 需要兼容的 CPU。检查命令:kt doctor",
"run_int_step4_title": "第 4 步:GPU 专家配置",
"run_int_gpu_experts": "每层 GPU 专家数(0-{max})",
"run_int_gpu_experts_info": "总专家数:{total},每 token 激活:{active}",
"run_int_step5_title": "第 5 步:KV Cache 配置",
"run_int_kv_cache_size": "KV Cache 大小(tokens)",
"run_int_chunk_prefill": "启用分块预填充?",
"run_int_chunk_size": "分块预填充大小(tokens)",
"run_int_gpu_prefill_threshold": "GPU 预填充阈值(tokens)",
"run_int_step6_title": "第 6 步:GPU 选择和张量并行",
"run_int_available_gpus": "可用 GPU:",
"run_int_gpu_id": "GPU {id}",
"run_int_vram_info": "{name}(总计 {total:.1f}GB,空闲 {free:.1f}GB)",
"run_int_select_gpus": "选择 GPU ID(逗号分隔)",
"run_int_invalid_gpu_range": "所有 GPU ID 必须在 0 到 {max} 之间",
"run_int_tp_size": "TP 大小(必须是 2 的幂:1,2,4,8...)",
"run_int_tp_mismatch": "TP 大小必须与选择的 GPU 数量匹配({count})",
"run_int_tp_not_power_of_2": "TP 大小必须是 2 的幂",
"run_int_mem_fraction": "静态内存占用比例(0.0-1.0)",
"run_int_using_saved_mem": "使用已保存的内存占用比例:{fraction}",
"run_int_step7_title": "第 7 步:解析器配置(可选)",
"run_int_tool_call_parser": "工具调用解析器(按回车跳过)",
"run_int_reasoning_parser": "推理解析器(按回车跳过)",
"run_int_step8_title": "第 8 步:主机和端口配置",
"run_int_host": "主机",
"run_int_port": "端口",
"run_int_port_occupied": "⚠ 端口 {port} 已被占用",
"run_int_port_suggestion": "建议使用可用端口:{port}",
"run_int_use_suggested": "使用建议的端口?",
"run_int_saved_configs": "已保存的配置:",
"run_int_config_name": "配置 {num}",
"run_int_kt_method": "KT 方法:",
"run_int_numa_nodes_label": "NUMA 节点:",
"run_int_cpu_threads_label": "CPU 线程:",
"run_int_gpu_experts_label": "GPU 专家:",
"run_int_tp_size_label": "TP 大小:",
"run_int_mem_fraction_label": "内存占用比例:",
"run_int_server_label": "服务器:",
"run_int_kv_cache_label": "KV Cache:",
"run_int_chunk_prefill_label": "分块预填充:",
"run_int_gpu_prefill_label": "GPU 预填充阈值:",
"run_int_tool_parser_label": "工具调用解析器:",
"run_int_reasoning_parser_label": "推理解析器:",
"run_int_command_label": "命令:",
"run_int_select_config": "选择配置",
"run_int_gpu_select_required": "请选择 {tp} 个 GPU(来自已保存配置的 TP 大小)",
"run_int_port_check_title": "端口配置",
"run_int_port_checking": "正在检查端口 {port} 可用性...",
"run_int_port_available": "端口 {port} 可用",
"run_int_saved_config_title": "已保存的配置",
"run_int_save_config_title": "保存配置",
"run_int_save_config_prompt": "保存此配置以供将来使用?",
"run_int_config_name_prompt": "配置名称",
"run_int_config_name_default": "配置 {timestamp}",
"run_int_config_saved": "配置已保存:{name}",
"run_int_config_summary": "配置完成",
"run_int_model_label": "模型:",
"run_int_selected_gpus_label": "已选择的 GPU:",
# Model command
"model_supported_title": "KTransformers 支持的模型",
"model_column_model": "模型",
@@ -557,6 +1073,180 @@ MESSAGES: dict[str, dict[str, str]] = {
"model_column_name": "名称",
"model_column_hf_repo": "HuggingFace 仓库",
"model_column_aliases": "别名",
# Model management - new user registry system
"model_no_registered_models": "尚未注册任何模型。",
"model_scan_hint": "扫描模型: kt model scan",
"model_add_hint": "添加模型: kt model add /path/to/model",
"model_registered_models_title": "已注册的模型",
"model_column_format": "格式",
"model_column_repo": "仓库",
"model_column_sha256": "SHA256",
"model_non_moe_hidden_hint": "检测到 {count} 个非MoE模型,使用 kt model list --all 展示全部",
"model_usage_title": "常用操作:",
"model_usage_info": "查看详情:",
"model_usage_edit": "编辑模型:",
"model_usage_verify": "校验权重:",
"model_usage_quant": "量化模型:",
"model_usage_run": "运行模型:",
"model_usage_scan": "扫描模型:",
"model_usage_add": "添加模型:",
"model_usage_verbose": "查看包含文件详情:",
"model_no_storage_paths": "未配置存储路径。",
"model_add_path_hint": "添加存储路径: kt config set model.storage_paths /path/to/models",
"model_scanning_paths": "正在扫描配置的存储路径...",
"model_scanning_progress": "扫描中: {path}",
"model_scan_warnings_title": "警告",
"model_scan_no_models_found": "在配置的路径中未找到模型。",
"model_scan_check_paths_hint": "检查存储路径: kt config get model.storage_paths",
"model_scan_min_size_hint": "文件夹必须 ≥{size}GB 才能被识别为模型。",
"model_scan_found_title": "发现 {count} 个新模型",
"model_column_path": "路径",
"model_column_size": "大小",
"model_scan_auto_adding": "正在自动添加模型...",
"model_added": "已添加: {name}",
"model_add_failed": "添加 {name} 失败: {error}",
"model_scan_complete": "扫描完成!已添加 {count} 个模型。",
"model_scan_interactive_prompt": "命令: edit <id> | del <id> | done",
"model_scan_cmd_edit": "设置模型自定义名称和仓库",
"model_scan_cmd_delete": "跳过此模型",
"model_scan_cmd_done": "完成并添加模型",
"model_scan_marked_skip": "已跳过模型 #{id}",
"model_scan_invalid_id": "无效的模型 ID: {id}",
"model_scan_invalid_command": "无效命令。使用: edit <id> | del <id> | done",
"model_scan_edit_model": "编辑模型 {id}",
"model_scan_edit_note": "您可以在添加到注册表前更改模型名称和配置仓库信息",
"model_scan_adding_models": "正在添加 {count} 个模型...",
"model_scan_next_steps": "后续步骤",
"model_scan_view_hint": "查看已注册模型: kt model list",
"model_scan_edit_hint": "编辑模型详情: kt model edit <name>",
"model_scan_no_models_added": "未添加任何模型。",
"model_add_path_not_exist": "错误: 路径不存在: {path}",
"model_add_not_directory": "错误: 路径不是目录: {path}",
"model_add_already_registered": "此路径已注册为: {name}",
"model_add_view_hint": "查看: kt model info {name}",
"model_add_scanning": "正在扫描模型文件...",
"model_add_scan_failed": "扫描模型失败: {error}",
"model_add_no_model_files": "在 {path} 中未找到模型文件",
"model_add_supported_formats": "支持: *.safetensors, *.gguf (文件夹 ≥10GB)",
"model_add_detected": "检测到: {format} 格式, {size}, {count} 个文件",
"model_add_name_conflict": "名称 '{name}' 已存在。",
"model_add_prompt_name": "为此模型输入名称",
"model_add_name_exists": "名称已存在。请选择其他名称:",
"model_add_configure_repo": "配置仓库信息以进行 SHA256 验证?",
"model_add_repo_type_prompt": "选择仓库类型:",
"model_add_choice": "选择",
"model_add_repo_id_prompt": "输入仓库 ID (例如: deepseek-ai/DeepSeek-V3)",
"model_add_success": "成功添加模型: {name}",
"model_add_verify_hint": "验证完整性: kt model verify {name}",
"model_add_edit_later_hint": "稍后编辑详情: kt model edit {name}",
"model_add_failed_generic": "添加模型失败: {error}",
"model_edit_not_found": "未找到模型 '{name}'。",
"model_edit_list_hint": "列出模型: kt model list",
"model_edit_current_config": "当前配置",
"model_edit_what_to_edit": "您想编辑什么?",
"model_edit_option_name": "编辑名称",
"model_edit_option_repo": "配置仓库信息",
"model_edit_option_delete": "删除此模型",
"model_edit_option_cancel": "取消 / 退出",
"model_edit_choice_prompt": "选择选项",
"model_edit_new_name": "输入新名称",
"model_edit_name_conflict": "名称 '{name}' 已存在。请选择其他名称:",
"model_edit_name_updated": "名称已更新: {old} → {new}",
"model_edit_repo_type_prompt": "仓库类型 (或按回车删除仓库信息):",
"model_edit_repo_remove": "删除仓库信息",
"model_edit_repo_id_prompt": "输入仓库 ID",
"model_edit_repo_removed": "仓库信息已删除",
"model_edit_repo_updated": "仓库已配置: {repo_type} → {repo_id}",
"model_edit_delete_warning": "从注册表中删除模型 '{name}'?",
"model_edit_delete_note": "注意: 这只会删除注册表条目。{path} 中的模型文件不会被删除。",
"model_edit_delete_confirm": "确认删除?",
"model_edit_deleted": "模型 '{name}' 已从注册表中删除",
"model_edit_delete_cancelled": "删除已取消",
"model_edit_cancelled": "编辑已取消",
# Model edit - Interactive selection
"model_edit_select_title": "选择要编辑的模型",
"model_edit_select_model": "选择模型",
"model_edit_invalid_choice": "无效选择",
"model_edit_no_models": "注册表中未找到模型。",
"model_edit_add_hint_scan": "添加模型:",
"model_edit_add_hint_add": "或:",
# Model edit - Display
"model_edit_gpu_links": "GPU 链接:",
# Model edit - Menu options
"model_edit_manage_gpu_links": "管理 GPU 链接",
"model_edit_save_changes": "保存更改",
"model_edit_has_changes": "(有更改)",
"model_edit_no_changes": "(无更改)",
# Model edit - Pending changes messages
"model_edit_name_pending": "名称将在保存更改时更新。",
"model_edit_repo_remove_pending": "仓库信息将在保存更改时删除。",
"model_edit_repo_update_pending": "仓库信息将在保存更改时更新。",
# Model edit - GPU link management
"model_edit_gpu_links_title": "管理 {name} 的 GPU 链接",
"model_edit_current_gpu_links": "当前 GPU 链接:",
"model_edit_no_gpu_links": "未配置 GPU 链接。",
"model_edit_gpu_options": "选项:",
"model_edit_gpu_add": "添加 GPU 链接",
"model_edit_gpu_remove": "删除 GPU 链接",
"model_edit_gpu_clear": "清除所有 GPU 链接",
"model_edit_gpu_back": "返回主菜单",
"model_edit_gpu_choose_option": "选择选项",
"model_edit_gpu_none_available": "没有可链接的 GPU 模型。",
"model_edit_gpu_available_models": "可用的 GPU 模型:",
"model_edit_gpu_already_linked": "(已链接)",
"model_edit_gpu_enter_number": "输入要添加的 GPU 模型编号",
"model_edit_gpu_link_pending": "GPU 链接将在保存更改时添加: {name}",
"model_edit_gpu_already_exists": "此 GPU 模型已链接。",
"model_edit_gpu_invalid_choice": "无效选择。",
"model_edit_gpu_invalid_input": "无效输入。",
"model_edit_gpu_none_to_remove": "没有可删除的 GPU 链接。",
"model_edit_gpu_choose_to_remove": "选择要删除的 GPU 链接:",
"model_edit_gpu_enter_to_remove": "输入要删除的编号",
"model_edit_gpu_remove_pending": "GPU 链接将在保存更改时删除: {name}",
"model_edit_gpu_none_to_clear": "没有可清除的 GPU 链接。",
"model_edit_gpu_clear_confirm": "删除所有 GPU 链接?",
"model_edit_gpu_clear_pending": "所有 GPU 链接将在保存更改时删除。",
"model_edit_cancelled_short": "已取消。",
# Model edit - Save operation
"model_edit_no_changes_to_save": "没有更改可保存。",
"model_edit_saving": "正在保存更改...",
"model_edit_saved": "更改保存成功!",
"model_edit_updated_config": "更新后的配置:",
"model_edit_repo_changed_warning": "⚠ 仓库信息已更改。",
"model_edit_verify_hint": "运行 [cyan]kt model verify[/cyan] 以使用 SHA256 校验和验证模型完整性。",
"model_edit_discard_changes": "放弃未保存的更改?",
"model_info_not_found": "未找到模型 '{name}'。",
"model_info_list_hint": "列出所有模型: kt model list",
"model_remove_not_found": "未找到模型 '{name}'。",
"model_remove_list_hint": "列出模型: kt model list",
"model_remove_warning": "从注册表中删除模型 '{name}'?",
"model_remove_note": "注意: 这只会删除注册表条目。模型文件不会从 {path} 中删除。",
"model_remove_confirm": "确认删除?",
"model_remove_cancelled": "删除已取消",
"model_removed": "模型 '{name}' 已从注册表中删除",
"model_remove_failed": "删除模型失败: {error}",
"model_refresh_checking": "正在检查模型路径...",
"model_refresh_all_valid": "所有模型都有效! (已检查 {count} 个模型)",
"model_refresh_total": "总模型数: {total}",
"model_refresh_missing_found": "发现 {count} 个缺失的模型",
"model_refresh_suggestions": "建议操作",
"model_refresh_remove_hint": "从注册表中删除: kt model remove <name>",
"model_refresh_rescan_hint": "重新扫描模型: kt model scan",
"model_verify_not_found": "未找到模型 '{name}'。",
"model_verify_list_hint": "列出模型: kt model list",
"model_verify_no_repo": "模型 '{name}' 未配置仓库信息。",
"model_verify_config_hint": "配置仓库: kt model edit {name}",
"model_verify_path_missing": "模型路径不存在: {path}",
"model_verify_starting": "正在验证模型完整性...",
"model_verify_progress": "仓库: {repo_type} → {repo_id}",
"model_verify_not_implemented": "SHA256 验证尚未实现",
"model_verify_future_note": "此功能将从 {repo_type} 获取官方 SHA256 哈希值并与本地文件进行比较。",
"model_verify_passed": "验证通过!所有文件都与官方哈希匹配。",
"model_verify_failed": "验证失败!{count} 个文件的哈希不匹配。",
"model_verify_all_no_repos": "没有模型配置了仓库信息。",
"model_verify_all_config_hint": "配置仓库使用: kt model edit <name>",
"model_verify_all_found": "发现 {count} 个配置了仓库信息的模型",
"model_verify_all_manual_hint": "验证特定模型: kt model verify <name>",
# Coming soon
"feature_coming_soon": "此功能即将推出...",
},

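# A minimal sketch of the lookup assumed behind t() -- the real helper lives in
# kt_kernel.cli.i18n and its exact signature is not shown in this diff:
#
#     def t(key: str, **kwargs) -> str:
#         text = MESSAGES.get(get_lang(), MESSAGES["en"]).get(key, key)
#         return text.format(**kwargs) if kwargs else text
#
# Placeholders such as {path}, {count} and {total:.1f} in the tables above are
# consistent with str.format()-style substitution.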
@@ -5,6 +5,10 @@ KTransformers CLI - A unified command-line interface for KTransformers.
"""

import sys
import warnings

# Suppress numpy subnormal warnings
warnings.filterwarnings("ignore", message="The value of the smallest subnormal")
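# For reference: this is NumPy's UserWarning, emitted on hardware/builds that
# flush subnormal floats to zero. It can be reproduced with, e.g.:
#
#     import numpy as np
#     np.finfo(np.float32).smallest_subnormal
#     # UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.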

import typer

@@ -28,6 +32,7 @@ def _get_help(key: str) -> str:
"run": {"en": "Start model inference server", "zh": "启动模型推理服务器"},
"chat": {"en": "Interactive chat with running model", "zh": "与运行中的模型进行交互式聊天"},
"quant": {"en": "Quantize model weights", "zh": "量化模型权重"},
"edit": {"en": "Edit model information", "zh": "编辑模型信息"},
"bench": {"en": "Run full benchmark", "zh": "运行完整基准测试"},
"microbench": {"en": "Run micro-benchmark", "zh": "运行微基准测试"},
"doctor": {"en": "Diagnose environment issues", "zh": "诊断环境问题"},
@@ -43,7 +48,7 @@ def _get_help(key: str) -> str:
app = typer.Typer(
name="kt",
help="KTransformers CLI - A unified command-line interface for KTransformers.",
no_args_is_help=True,
no_args_is_help=False,  # Handle no-args case manually to support first-run setup
add_completion=False,  # Use static completion scripts instead of dynamic completion
rich_markup_mode="rich",
)
@@ -66,20 +71,7 @@ def _update_help_texts() -> None:
group_info.help = _get_help(group_info.name)


# Register commands
app.command(name="version", help="Show version information")(version.version)
# Run command is handled specially in main() to allow extra args
# (not registered here to avoid typer's argument parsing)
app.command(name="chat", help="Interactive chat with running model")(chat.chat)
app.command(name="quant", help="Quantize model weights")(quant.quant)
app.command(name="bench", help="Run full benchmark")(bench.bench)
app.command(name="microbench", help="Run micro-benchmark")(bench.microbench)
app.command(name="doctor", help="Diagnose environment issues")(doctor.doctor)

# Register sub-apps
app.add_typer(model.app, name="model", help="Manage models and storage paths")
app.add_typer(config.app, name="config", help="Manage configuration")
app.add_typer(sft.app, name="sft", help="Fine-tuning with LlamaFactory")
# Commands are registered later after tui_command is defined


def check_first_run() -> None:
@@ -116,7 +108,7 @@ def _show_first_run_setup(settings) -> None:
from rich.spinner import Spinner
from rich.live import Live

from kt_kernel.cli.utils.environment import scan_storage_locations, format_size_gb, scan_models_in_location
from kt_kernel.cli.utils.environment import scan_storage_locations, format_size_gb

console = Console()

@@ -140,15 +132,8 @@ def _show_first_run_setup(settings) -> None:
console.print(" [cyan][2][/cyan] 中文 (Chinese)")
console.print()

while True:
choice = Prompt.ask("Enter choice / 输入选择", choices=["1", "2"], default="1")

if choice == "1":
lang = "en"
break
elif choice == "2":
lang = "zh"
break
choice = Prompt.ask("Enter choice / 输入选择", choices=["1", "2"], default="1")
lang = "en" if choice == "1" else "zh"

# Save language setting
settings.set("general.language", lang)
@@ -161,6 +146,131 @@ def _show_first_run_setup(settings) -> None:
else:
console.print("[green]✓[/green] Language set to English")

# Model discovery section
console.print()
if lang == "zh":
console.print("[bold]发现模型权重[/bold]")
console.print()
console.print("[dim]扫描系统中已有的模型权重文件,以便快速添加到模型列表。[/dim]")
console.print()
console.print(" [cyan][1][/cyan] 全局扫描 (自动扫描所有非系统路径)")
console.print(" [cyan][2][/cyan] 手动指定路径 (可添加多个)")
console.print(" [cyan][3][/cyan] 跳过 (稍后手动添加)")
console.print()
scan_choice = Prompt.ask("选择扫描方式", choices=["1", "2", "3"], default="1")
else:
console.print("[bold]Discover Model Weights[/bold]")
console.print()
console.print("[dim]Scan existing model weights on your system to quickly add them to the model list.[/dim]")
console.print()
console.print(" [cyan][1][/cyan] Global scan (auto-scan all non-system paths)")
console.print(" [cyan][2][/cyan] Manual paths (add multiple paths)")
console.print(" [cyan][3][/cyan] Skip (add manually later)")
console.print()
scan_choice = Prompt.ask("Select scan method", choices=["1", "2", "3"], default="1")

if scan_choice == "1":
# Global scan
from kt_kernel.cli.utils.model_discovery import discover_and_register_global, format_discovery_summary

console.print()
try:
total_found, new_found, registered = discover_and_register_global(
min_size_gb=2.0, max_depth=6, show_progress=True, lang=lang
)

format_discovery_summary(
total_found=total_found,
new_found=new_found,
registered=registered,
lang=lang,
show_models=True,
max_show=10,
)

except Exception as e:
console.print(f"[yellow]Warning: Scan failed - {e}[/yellow]")

elif scan_choice == "2":
# Manual path specification
from kt_kernel.cli.utils.model_discovery import discover_and_register_path
import os

discovered_paths = set()  # Track paths discovered in this session
total_registered = []

while True:
console.print()
if lang == "zh":
path = Prompt.ask("输入要扫描的路径 (例如: /mnt/data/models)")
else:
path = Prompt.ask("Enter path to scan (e.g., /mnt/data/models)")

# Expand and validate path
path = os.path.expanduser(path)

if not os.path.exists(path):
if lang == "zh":
console.print(f"[yellow]警告: 路径不存在: {path}[/yellow]")
else:
console.print(f"[yellow]Warning: Path does not exist: {path}[/yellow]")
continue

if not os.path.isdir(path):
if lang == "zh":
console.print(f"[yellow]警告: 不是一个目录: {path}[/yellow]")
else:
console.print(f"[yellow]Warning: Not a directory: {path}[/yellow]")
continue

# Scan this path
console.print()
try:
total_found, new_found, registered = discover_and_register_path(
path=path, min_size_gb=2.0, existing_paths=discovered_paths, show_progress=True, lang=lang
)

# Update discovered paths
for model in registered:
discovered_paths.add(model.path)
total_registered.extend(registered)

console.print()
if lang == "zh":
console.print(f"[green]✓[/green] 在此路径找到 {total_found} 个模型,其中 {new_found} 个为新模型")
else:
console.print(f"[green]✓[/green] Found {total_found} models in this path, {new_found} are new")

if new_found > 0:
for model in registered[:5]:
console.print(f" • {model.name} ({model.format})")

if len(registered) > 5:
if lang == "zh":
console.print(f" [dim]... 还有 {len(registered) - 5} 个新模型[/dim]")
else:
console.print(f" [dim]... and {len(registered) - 5} more new models[/dim]")

except Exception as e:
console.print(f"[red]Error scanning path: {e}[/red]")

# Ask if continue
console.print()
if lang == "zh":
continue_scan = Confirm.ask("是否继续添加其他路径?", default=False)
else:
continue_scan = Confirm.ask("Continue adding more paths?", default=False)

if not continue_scan:
break

if total_registered:
console.print()
if lang == "zh":
console.print(f"[green]✓[/green] 总共发现 {len(total_registered)} 个新模型")
else:
console.print(f"[green]✓[/green] Total {len(total_registered)} new models discovered")

# Model storage path selection
console.print()
console.print(f"[bold]{t('setup_model_path_title')}[/bold]")
@@ -174,16 +284,7 @@ def _show_first_run_setup(settings) -> None:
console.print()

if locations:
# Scan for models in each location
console.print(f"[dim]{t('setup_scanning_models')}[/dim]")
location_models: dict[str, list] = {}
for loc in locations[:5]:
models = scan_models_in_location(loc, max_depth=2)
if models:
location_models[loc.path] = models
console.print()

# Show options
# Show storage location options
for i, loc in enumerate(locations[:5], 1):  # Show top 5 options
available = format_size_gb(loc.available_gb)
total = format_size_gb(loc.total_gb)
@@ -194,22 +295,8 @@ def _show_first_run_setup(settings) -> None:
else:
option_str = t("setup_disk_option", path=loc.path, available=available, total=total)

# Add model count if any
if loc.path in location_models:
model_count = len(location_models[loc.path])
option_str += f" [green]✓ {t('setup_location_has_models', count=model_count)}[/green]"

console.print(f" [cyan][{i}][/cyan] {option_str}")

# Show first few models found in this location
if loc.path in location_models:
for model in location_models[loc.path][:3]:  # Show up to 3 models
size_str = format_size_gb(model.size_gb)
console.print(f" [dim]• {model.name} ({size_str})[/dim]")
if len(location_models[loc.path]) > 3:
remaining = len(location_models[loc.path]) - 3
console.print(f" [dim] ... +{remaining} more[/dim]")

# Custom path option
custom_idx = min(len(locations), 5) + 1
console.print(f" [cyan][{custom_idx}][/cyan] {t('setup_custom_path')}")
|
||||
@@ -323,51 +410,28 @@ def _install_shell_completion() -> None:
|
||||
|
||||
# Detect current shell
|
||||
shell = os.environ.get("SHELL", "")
|
||||
if "zsh" in shell:
|
||||
shell_name = "zsh"
|
||||
elif "fish" in shell:
|
||||
shell_name = "fish"
|
||||
else:
|
||||
shell_name = "bash"
|
||||
shell_name = "zsh" if "zsh" in shell else "fish" if "fish" in shell else "bash"
|
||||
|
||||
try:
|
||||
cli_dir = Path(__file__).parent
|
||||
completions_dir = cli_dir / "completions"
|
||||
home = Path.home()
|
||||
|
||||
installed = False
|
||||
def install_completion(src_name: str, dest_dir: Path, dest_name: str) -> None:
|
||||
"""Install completion file from source to destination."""
|
||||
src_file = completions_dir / src_name
|
||||
if src_file.exists():
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_file, dest_dir / dest_name)
|
||||
|
||||
if shell_name == "bash":
|
||||
# Use XDG standard location for bash-completion (auto-loaded)
|
||||
src_file = completions_dir / "kt-completion.bash"
|
||||
dest_dir = home / ".local" / "share" / "bash-completion" / "completions"
|
||||
dest_file = dest_dir / "kt"
|
||||
|
||||
if src_file.exists():
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_file, dest_file)
|
||||
installed = True
|
||||
|
||||
install_completion(
|
||||
"kt-completion.bash", home / ".local" / "share" / "bash-completion" / "completions", "kt"
|
||||
)
|
||||
elif shell_name == "zsh":
|
||||
src_file = completions_dir / "_kt"
|
||||
dest_dir = home / ".zfunc"
|
||||
dest_file = dest_dir / "_kt"
|
||||
|
||||
if src_file.exists():
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_file, dest_file)
|
||||
installed = True
|
||||
|
||||
install_completion("_kt", home / ".zfunc", "_kt")
|
||||
elif shell_name == "fish":
|
||||
# Fish auto-loads from this directory
|
||||
src_file = completions_dir / "kt.fish"
|
||||
dest_dir = home / ".config" / "fish" / "completions"
|
||||
dest_file = dest_dir / "kt.fish"
|
||||
|
||||
if src_file.exists():
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_file, dest_file)
|
||||
installed = True
|
||||
install_completion("kt.fish", home / ".config" / "fish" / "completions", "kt.fish")
|
||||
|
||||
# Mark as installed
|
||||
settings.set("general._completion_installed", True)
|
||||
@@ -403,6 +467,20 @@ def _apply_saved_language() -> None:
|
||||
set_lang(lang)
|
||||
|
||||
|
||||
app.command(name="version", help="Show version information")(version.version)
|
||||
app.command(name="chat", help="Interactive chat with running model")(chat.chat)
|
||||
app.command(name="quant", help="Quantize model weights")(quant.quant)
|
||||
app.command(name="edit", help="Edit model information")(model.edit_model)
|
||||
app.command(name="bench", help="Run full benchmark")(bench.bench)
|
||||
app.command(name="microbench", help="Run micro-benchmark")(bench.microbench)
|
||||
app.command(name="doctor", help="Diagnose environment issues")(doctor.doctor)
|
||||
|
||||
# Register sub-apps
|
||||
app.add_typer(model.app, name="model", help="Manage models and storage paths")
|
||||
app.add_typer(config.app, name="config", help="Manage configuration")
|
||||
app.add_typer(sft.app, name="sft", help="Fine-tuning with LlamaFactory")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
# Apply saved language setting first (before anything else for correct help display)
|
||||
@@ -414,7 +492,7 @@ def main():
|
||||
# Check for first run (but not for certain commands)
|
||||
# Skip first-run check for: --help, config commands, version
|
||||
args = sys.argv[1:] if len(sys.argv) > 1 else []
|
||||
skip_commands = ["--help", "-h", "config", "version", "--version"]
|
||||
skip_commands = ["--help", "-h", "config", "version", "--version", "--no-tui"]
|
||||
|
||||
should_check_first_run = True
|
||||
for arg in args:
|
||||
@@ -422,12 +500,35 @@ def main():
|
||||
should_check_first_run = False
|
||||
break
|
||||
|
||||
# Handle no arguments case
|
||||
if not args:
|
||||
# Check if this is first run
|
||||
from kt_kernel.cli.config.settings import DEFAULT_CONFIG_FILE, get_settings
|
||||
|
||||
is_first_run = False
|
||||
if not DEFAULT_CONFIG_FILE.exists():
|
||||
is_first_run = True
|
||||
else:
|
||||
settings = get_settings()
|
||||
if not settings.get("general._initialized"):
|
||||
is_first_run = True
|
||||
|
||||
if is_first_run:
|
||||
# First run - start initialization
|
||||
_install_shell_completion()
|
||||
check_first_run()
|
||||
return
|
||||
else:
|
||||
# Not first run - show help
|
||||
app(["--help"])
|
||||
return
|
||||
|
||||
# Auto-install shell completion on first run
|
||||
if should_check_first_run:
|
||||
_install_shell_completion()
|
||||
|
||||
# Check first run before running commands
|
||||
if should_check_first_run and args:
|
||||
if should_check_first_run:
|
||||
check_first_run()
|
||||
|
||||
# Handle "run" command specially to pass through unknown options
|
||||
|
||||
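The "global scan" branch of the first-run flow above reduces to two calls into the new discovery utilities (defined in kt-kernel/python/cli/utils/model_discovery.py, added below). A minimal standalone sketch of the same sequence, shown here for reference:

    from kt_kernel.cli.utils.model_discovery import discover_and_register_global, format_discovery_summary

    # Same arguments the setup flow passes for the "global scan" choice.
    total_found, new_found, registered = discover_and_register_global(
        min_size_gb=2.0, max_depth=6, show_progress=True, lang="en"
    )
    format_discovery_summary(
        total_found=total_found, new_found=new_found, registered=registered,
        lang="en", show_models=True, max_show=10,
    )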
kt-kernel/python/cli/utils/analyze_moe_model.py (new file, 413 lines)
@@ -0,0 +1,413 @@
#!/usr/bin/env python3
"""
Quick MoE model analysis based on config.json
(reuses sglang's model registry and detection logic)
"""
import json
import hashlib
from pathlib import Path
from typing import Optional, Dict, Any


def _get_sglang_moe_architectures():
    """
    Get all MoE architectures from sglang's model registry

    Reuses sglang's code, so new models are supported automatically when sglang is updated
    """
    try:
        import sys

        # Add the sglang path to sys.path
        sglang_path = Path("/mnt/data2/ljq/sglang/python")
        if sglang_path.exists() and str(sglang_path) not in sys.path:
            sys.path.insert(0, str(sglang_path))

        # Import sglang's ModelRegistry directly
        # Note: this requires sglang and its dependencies to be installed correctly
        from sglang.srt.models.registry import ModelRegistry

        # Get all supported architectures
        supported_archs = ModelRegistry.get_supported_archs()

        # Filter for MoE models (names containing "Moe")
        moe_archs = {arch for arch in supported_archs if "Moe" in arch or "moe" in arch.lower()}

        # Manually add architectures that are MoE models but lack "Moe" in the name
        # DeepSeek V2/V3 series
        deepseek_moe = {arch for arch in supported_archs if arch.startswith("Deepseek") or arch.startswith("deepseek")}
        moe_archs.update(deepseek_moe)

        # DBRX is also an MoE model
        dbrx_moe = {arch for arch in supported_archs if "DBRX" in arch or "dbrx" in arch.lower()}
        moe_archs.update(dbrx_moe)

        # Grok is also an MoE model
        grok_moe = {arch for arch in supported_archs if "Grok" in arch or "grok" in arch.lower()}
        moe_archs.update(grok_moe)

        return moe_archs
    except Exception as e:
        # If sglang is unavailable, return an empty set
        # In that case, the fallback checks on the config file are used instead
        import warnings

        warnings.warn(f"Failed to load MoE architectures from sglang: {e}. Using fallback detection methods.")
        return set()


# Get the list of MoE architectures (preferring sglang)
MOE_ARCHITECTURES = _get_sglang_moe_architectures()


def _get_cache_file():
    """Get the path of the centralized cache file"""
    cache_dir = Path.home() / ".ktransformers" / "cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir / "moe_analysis_v2.json"


def _load_all_cache():
    """Load all cached data"""
    cache_file = _get_cache_file()
    if not cache_file.exists():
        return {}

    try:
        with open(cache_file, "r") as f:
            return json.load(f)
    except Exception:
        return {}


def _save_all_cache(cache_data):
    """Save all cached data"""
    cache_file = _get_cache_file()
    try:
        with open(cache_file, "w") as f:
            json.dump(cache_data, f, indent=2)
    except Exception as e:
        import warnings

        warnings.warn(f"Failed to save MoE cache: {e}")


def _compute_config_fingerprint(config_path: Path) -> Optional[str]:
    """Compute a fingerprint for config.json"""
    if not config_path.exists():
        return None

    try:
        stat = config_path.stat()
        # Use file size and modification time as the fingerprint
        fingerprint_str = f"{config_path.name}:{stat.st_size}:{int(stat.st_mtime)}"
        return hashlib.md5(fingerprint_str.encode()).hexdigest()
    except Exception:
        return None


def _load_cache(model_path: Path) -> Optional[Dict[str, Any]]:
    """Load the cache entry for a given model"""
    model_path_str = str(model_path.resolve())
    all_cache = _load_all_cache()

    if model_path_str not in all_cache:
        return None

    try:
        cache_entry = all_cache[model_path_str]

        # Validate the cache version
        cache_version = cache_entry.get("cache_version", 0)
        if cache_version != 2:
            return None

        # Validate the config.json fingerprint
        config_path = model_path / "config.json"
        current_fingerprint = _compute_config_fingerprint(config_path)
        if cache_entry.get("fingerprint") != current_fingerprint:
            return None

        return cache_entry.get("result")
    except Exception:
        return None


def _save_cache(model_path: Path, result: Dict[str, Any]):
    """Save the cache entry for a given model"""
    model_path_str = str(model_path.resolve())

    try:
        config_path = model_path / "config.json"
        fingerprint = _compute_config_fingerprint(config_path)

        all_cache = _load_all_cache()

        all_cache[model_path_str] = {
            "fingerprint": fingerprint,
            "result": result,
            "cache_version": 2,
            "last_updated": __import__("datetime").datetime.now().isoformat(),
        }

        _save_all_cache(all_cache)
    except Exception as e:
        import warnings

        warnings.warn(f"Failed to save MoE cache for {model_path}: {e}")


def _load_config_json(model_path: Path) -> Optional[Dict[str, Any]]:
    """Read the config.json file

    Modeled on sglang's get_config() implementation
    """
    config_path = model_path / "config.json"

    if not config_path.exists():
        return None

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        return config
    except Exception:
        return None


def _is_moe_model(config: Dict[str, Any]) -> bool:
    """Determine whether this is an MoE model

    Follows sglang's model registry and architecture detection
    """
    # Method 1: check the architecture names
    architectures = config.get("architectures", [])
    if any(arch in MOE_ARCHITECTURES for arch in architectures):
        return True

    # Method 2: check for an MoE-specific field (Mistral format)
    if config.get("moe"):
        return True

    # Method 3: check for num_experts or one of its variant fields
    # text_config must be checked too (for some multimodal models)
    text_config = config.get("text_config", config)

    # Check the various expert-count fields
    if (
        text_config.get("num_experts") or text_config.get("num_local_experts") or text_config.get("n_routed_experts")
    ):  # Kimi-K2 uses this field
        return True

    return False


def _extract_moe_params(config: Dict[str, Any]) -> Dict[str, Any]:
    """Extract MoE parameters from the config

    Based on sglang's various MoE model implementations
    """
    # Handle nested text_config
    text_config = config.get("text_config", config)

    # Extract basic parameters
    result = {
        "architectures": config.get("architectures", []),
        "model_type": config.get("model_type", "unknown"),
    }

    # Number of experts (the field name differs between models)
    num_experts = (
        text_config.get("num_experts")  # Qwen2/3 MoE, DeepSeek V2
        or text_config.get("num_local_experts")  # Mixtral
        or text_config.get("n_routed_experts")  # Kimi-K2, DeepSeek V3
        or config.get("moe", {}).get("num_experts")  # Mistral format
    )

    # Number of experts activated per token
    num_experts_per_tok = (
        text_config.get("num_experts_per_tok")
        or text_config.get("num_experts_per_token")
        or config.get("moe", {}).get("num_experts_per_tok")
        or 2  # default value
    )

    # Number of layers
    num_hidden_layers = text_config.get("num_hidden_layers") or text_config.get("n_layer") or 0

    # Hidden dimension
    hidden_size = text_config.get("hidden_size") or text_config.get("d_model") or 0

    # MoE expert intermediate size
    moe_intermediate_size = (
        text_config.get("moe_intermediate_size")
        or text_config.get("intermediate_size")  # if there is no dedicated moe_intermediate_size
        or 0
    )

    # Shared expert intermediate size (Qwen2/3 MoE)
    shared_expert_intermediate_size = text_config.get("shared_expert_intermediate_size", 0)

    result.update(
        {
            "num_experts": num_experts or 0,
            "num_experts_per_tok": num_experts_per_tok,
            "num_hidden_layers": num_hidden_layers,
            "hidden_size": hidden_size,
            "moe_intermediate_size": moe_intermediate_size,
            "shared_expert_intermediate_size": shared_expert_intermediate_size,
        }
    )

    # Extract other useful parameters
    result["num_attention_heads"] = text_config.get("num_attention_heads", 0)
    result["num_key_value_heads"] = text_config.get("num_key_value_heads", 0)
    result["vocab_size"] = text_config.get("vocab_size", 0)
    result["max_position_embeddings"] = text_config.get("max_position_embeddings", 0)

    return result


def _estimate_model_size(model_path: Path) -> float:
    """Estimate total model size in GB

    Quickly sums the sizes of the safetensors files
    """
    try:
        total_size = 0
        for file_path in model_path.glob("*.safetensors"):
            total_size += file_path.stat().st_size
        return total_size / (1024**3)
    except Exception:
        return 0.0


def analyze_moe_model(model_path, use_cache=True):
    """
    Quickly analyze an MoE model by reading only config.json

    Args:
        model_path: Model path (string or Path object)
        use_cache: Whether to use the cache (default True)

    Returns:
        dict: {
            'is_moe': whether this is an MoE model,
            'num_experts': total number of experts,
            'num_experts_per_tok': experts activated per token,
            'num_hidden_layers': number of layers,
            'hidden_size': hidden dimension,
            'moe_intermediate_size': MoE expert intermediate size,
            'shared_expert_intermediate_size': shared expert intermediate size,
            'architectures': list of model architectures,
            'model_type': model type,
            'total_size_gb': total model size (estimated, GB),
            'cached': whether the result was read from cache
        }
        Returns None if the model is not MoE or analysis fails
    """
    model_path = Path(model_path)

    if not model_path.exists():
        return None

    # Try to load from cache
    if use_cache:
        cached_result = _load_cache(model_path)
        if cached_result:
            cached_result["cached"] = True
            return cached_result

    # Read config.json
    config = _load_config_json(model_path)
    if not config:
        return None

    # Check whether this is an MoE model
    if not _is_moe_model(config):
        return None

    # Extract MoE parameters
    params = _extract_moe_params(config)

    # Validate required parameters
    if params["num_experts"] == 0:
        return None

    # Estimate model size
    total_size_gb = _estimate_model_size(model_path)

    # Assemble the result
    result = {
        "is_moe": True,
        "num_experts": params["num_experts"],
        "num_experts_per_tok": params["num_experts_per_tok"],
        "num_hidden_layers": params["num_hidden_layers"],
        "hidden_size": params["hidden_size"],
        "moe_intermediate_size": params["moe_intermediate_size"],
        "shared_expert_intermediate_size": params["shared_expert_intermediate_size"],
        "architectures": params["architectures"],
        "model_type": params["model_type"],
        "total_size_gb": total_size_gb,
        "cached": False,
        # Extra parameters
        "num_attention_heads": params.get("num_attention_heads", 0),
        "num_key_value_heads": params.get("num_key_value_heads", 0),
        "vocab_size": params.get("vocab_size", 0),
    }

    # Save to cache
    if use_cache:
        _save_cache(model_path, result)

    return result


def print_analysis(model_path):
    """Print the model analysis result"""
    print(f"Analyzing model: {model_path}\n")

    result = analyze_moe_model(model_path)

    if result is None:
        print("Not an MoE model, or analysis failed")
        return

    print("=" * 70)
    print("MoE Model Analysis Result")
    if result.get("cached"):
        print("[from cache]")
    print("=" * 70)
    print("Model architecture:")
    print(f"  - Architectures: {', '.join(result['architectures'])}")
    print(f"  - Type: {result['model_type']}")
    print()
    print("MoE structure:")
    print(f"  - Total experts: {result['num_experts']}")
    print(f"  - Active experts: {result['num_experts_per_tok']} experts/token")
    print(f"  - Layers: {result['num_hidden_layers']}")
    print(f"  - Hidden size: {result['hidden_size']}")
    print(f"  - MoE intermediate size: {result['moe_intermediate_size']}")
    if result["shared_expert_intermediate_size"] > 0:
        print(f"  - Shared expert intermediate size: {result['shared_expert_intermediate_size']}")
    print()
    print("Size statistics:")
    print(f"  - Total model size: {result['total_size_gb']:.2f} GB")
    print("=" * 70)
    print()


def main():
    import sys

    models = ["/mnt/data2/models/Qwen3-30B-A3B", "/mnt/data2/models/Qwen3-235B-A22B-Instruct-2507"]

    if len(sys.argv) > 1:
        models = [sys.argv[1]]

    for model_path in models:
        print_analysis(model_path)


if __name__ == "__main__":
    main()
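A short usage sketch for the analyzer above; the model path is a placeholder:

    from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model

    # Results are cached in ~/.ktransformers/cache/moe_analysis_v2.json.
    info = analyze_moe_model("/path/to/model", use_cache=True)  # placeholder path
    if info is not None:
        print(f"{info['num_experts']} experts, {info['num_experts_per_tok']} active per token, "
              f"~{info['total_size_gb']:.1f} GB")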
kt-kernel/python/cli/utils/debug_configs.py (new file, 118 lines)
@@ -0,0 +1,118 @@
"""
Debug utility to inspect saved run configurations.

Usage: python -m kt_kernel.cli.utils.debug_configs
"""

from pathlib import Path
import yaml
from rich.console import Console
from rich.table import Table
from rich import box

console = Console()


def main():
    """Show all saved configurations."""
    config_file = Path.home() / ".ktransformers" / "run_configs.yaml"

    console.print()
    console.print(f"[bold]Configuration file:[/bold] {config_file}")
    console.print()

    if not config_file.exists():
        console.print("[red]✗ Configuration file does not exist![/red]")
        console.print()
        console.print("No configurations have been saved yet.")
        return

    try:
        with open(config_file, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
    except Exception as e:
        console.print(f"[red]✗ Failed to load configuration file: {e}[/red]")
        return

    console.print("[green]✓[/green] Configuration file loaded")
    console.print()

    configs = data.get("configs", {})

    if not configs:
        console.print("[yellow]No saved configurations found.[/yellow]")
        return

    console.print(f"[bold]Found configurations for {len(configs)} model(s):[/bold]")
    console.print()

    for model_id, model_configs in configs.items():
        console.print(f"[cyan]Model ID:[/cyan] {model_id}")
        console.print(f"[dim]  {len(model_configs)} configuration(s)[/dim]")
        console.print()

        if not model_configs:
            continue

        # Display configs in a table
        table = Table(box=box.ROUNDED, show_header=True, header_style="bold cyan")
        table.add_column("#", justify="right", style="cyan")
        table.add_column("Name", style="white")
        table.add_column("Method", style="yellow")
        table.add_column("TP", justify="right", style="green")
        table.add_column("GPU Experts", justify="right", style="magenta")
        table.add_column("Created", style="dim")

        for i, cfg in enumerate(model_configs, 1):
            method = cfg.get("inference_method", "?")
            kt_method = cfg.get("kt_method", "?")
            method_display = f"{method.upper()}"
            if method == "raw":
                method_display += f" ({cfg.get('raw_method', '?')})"
            elif method == "amx":
                method_display += f" ({kt_method})"

            table.add_row(
                str(i),
                cfg.get("config_name", f"Config {i}"),
                method_display,
                str(cfg.get("tp_size", "?")),
                str(cfg.get("gpu_experts", "?")),
                cfg.get("created_at", "Unknown")[:19] if cfg.get("created_at") else "Unknown",
            )

        console.print(table)
        console.print()

    # Also check user_models.yaml to show model names
    console.print("[bold]Checking model registry...[/bold]")
    console.print()

    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry

    try:
        registry = UserModelRegistry()
        all_models = registry.list_models()

        console.print(f"[green]✓[/green] Found {len(all_models)} registered model(s)")
        console.print()

        # Map model IDs to names
        id_to_name = {m.id: m.name for m in all_models}

        console.print("[bold]Model ID → Name mapping:[/bold]")
        console.print()

        for model_id in configs.keys():
            model_name = id_to_name.get(model_id, "[red]Unknown (model not found in registry)[/red]")
            console.print(f"  {model_id[:8]}... → {model_name}")

        console.print()

    except Exception as e:
        console.print(f"[yellow]⚠ Could not load model registry: {e}[/yellow]")
        console.print()


if __name__ == "__main__":
    main()
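For reference, the run_configs.yaml layout this tool reads, inferred from the key accesses above; all values are illustrative:

    configs:
      <model-id>:                      # ID from the user model registry
        - config_name: "Config 1"
          inference_method: amx        # or "raw"
          kt_method: int8
          tp_size: 2
          gpu_experts: 32
          created_at: "2026-01-01T12:00:00"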
kt-kernel/python/cli/utils/download_helper.py (new file, 146 lines)
@@ -0,0 +1,146 @@
"""Helper functions for interactive model download."""

from pathlib import Path
from typing import Any, Dict, List, Tuple
import fnmatch


def list_remote_files_hf(repo_id: str, use_mirror: bool = False) -> List[Dict[str, Any]]:
    """
    List files in a HuggingFace repository.

    Returns:
        List of dicts with keys: 'path', 'size' (in bytes)
    """
    from huggingface_hub import HfApi
    import os

    # Set mirror if needed
    original_endpoint = os.environ.get("HF_ENDPOINT")
    if use_mirror and not original_endpoint:
        os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

    try:
        api = HfApi()
        files_info = api.list_repo_tree(repo_id=repo_id, recursive=True)

        result = []
        for item in files_info:
            # Skip directories
            if hasattr(item, "type") and item.type == "directory":
                continue

            # Get file info
            file_path = item.path if hasattr(item, "path") else str(item)
            file_size = item.size if hasattr(item, "size") else 0

            result.append({"path": file_path, "size": file_size})

        return result
    finally:
        # Restore original endpoint
        if use_mirror and not original_endpoint:
            os.environ.pop("HF_ENDPOINT", None)
        elif original_endpoint:
            os.environ["HF_ENDPOINT"] = original_endpoint


def list_remote_files_ms(repo_id: str) -> List[Dict[str, Any]]:
    """
    List files in a ModelScope repository.

    Returns:
        List of dicts with keys: 'path', 'size' (in bytes)
    """
    from modelscope.hub.api import HubApi

    api = HubApi()
    files_info = api.get_model_files(model_id=repo_id, recursive=True)

    result = []
    for file_info in files_info:
        file_path = file_info.get("Name", file_info.get("Path", ""))
        file_size = file_info.get("Size", 0)

        result.append({"path": file_path, "size": file_size})

    return result


def filter_files_by_pattern(files: List[Dict[str, Any]], pattern: str) -> List[Dict[str, Any]]:
    """Filter files by glob pattern."""
    if pattern == "*":
        return files

    filtered = []
    for file in files:
        # Check if filename matches pattern
        filename = Path(file["path"]).name
        full_path = file["path"]

        if fnmatch.fnmatch(filename, pattern) or fnmatch.fnmatch(full_path, pattern):
            filtered.append(file)

    return filtered


def calculate_total_size(files: List[Dict[str, Any]]) -> int:
    """Calculate total size of files in bytes."""
    return sum(f["size"] for f in files)


def format_file_list_table(files: List[Dict[str, Any]], max_display: int = 10):
    """Format file list as a table for display."""
    from rich.table import Table
    from kt_kernel.cli.utils.model_scanner import format_size

    table = Table(show_header=True, header_style="bold")
    table.add_column("File", style="cyan", overflow="fold")
    table.add_column("Size", justify="right")

    # Show first max_display files
    for file in files[:max_display]:
        table.add_row(file["path"], format_size(file["size"]))

    if len(files) > max_display:
        table.add_row(f"... and {len(files) - max_display} more files", "[dim]...[/dim]")

    return table


def verify_repo_exists(repo_id: str, repo_type: str, use_mirror: bool = False) -> Tuple[bool, str]:
    """
    Verify if a repository exists.

    Returns:
        (exists: bool, message: str)
    """
    try:
        if repo_type == "huggingface":
            import os

            original_endpoint = os.environ.get("HF_ENDPOINT")
            if use_mirror and not original_endpoint:
                os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

            from huggingface_hub import HfApi

            try:
                api = HfApi()
                api.repo_info(repo_id=repo_id, repo_type="model")
                return True, "Repository found"
            finally:
                if use_mirror and not original_endpoint:
                    os.environ.pop("HF_ENDPOINT", None)
                elif original_endpoint:
                    os.environ["HF_ENDPOINT"] = original_endpoint

        else:  # modelscope
            from modelscope.hub.api import HubApi

            api = HubApi()
            api.get_model(model_id=repo_id)
            return True, "Repository found"

    except Exception as e:
        return False, f"Repository not found: {str(e)}"
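A sketch tying these helpers together; the repository ID is only an example:

    from kt_kernel.cli.utils.download_helper import (
        calculate_total_size, filter_files_by_pattern, list_remote_files_hf,
    )
    from kt_kernel.cli.utils.model_scanner import format_size

    files = list_remote_files_hf("org/model-name", use_mirror=False)  # placeholder repo ID
    weights = filter_files_by_pattern(files, "*.safetensors")
    print(f"{len(weights)} weight files, {format_size(calculate_total_size(weights))} total")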
kt-kernel/python/cli/utils/input_validators.py (new file, 216 lines)
@@ -0,0 +1,216 @@
"""
Input validation utilities with retry mechanism.

Provides robust input validation with automatic retry on failure.
"""

from typing import Optional, List, Callable, Any
from rich.console import Console
from rich.prompt import Prompt

console = Console()


def prompt_int_with_retry(
    message: str,
    default: Optional[int] = None,
    min_val: Optional[int] = None,
    max_val: Optional[int] = None,
    validator: Optional[Callable[[int], bool]] = None,
    validator_error_msg: Optional[str] = None,
) -> int:
    """Prompt for integer input with validation and retry.

    Args:
        message: Prompt message
        default: Default value (optional)
        min_val: Minimum allowed value (optional)
        max_val: Maximum allowed value (optional)
        validator: Custom validation function (optional)
        validator_error_msg: Error message for custom validator (optional)

    Returns:
        Validated integer value
    """
    while True:
        # Build prompt with default
        if default is not None:
            prompt_text = f"{message} [{default}]"
        else:
            prompt_text = message

        # Get input
        user_input = Prompt.ask(prompt_text, default=str(default) if default is not None else None)

        # Try to parse as integer
        try:
            value = int(user_input)
        except ValueError:
            console.print("[red]✗ Invalid input. Please enter a valid integer.[/red]")
            console.print()
            continue

        # Validate range
        if min_val is not None and value < min_val:
            console.print(f"[red]✗ Value must be at least {min_val}[/red]")
            console.print()
            continue

        if max_val is not None and value > max_val:
            console.print(f"[red]✗ Value must be at most {max_val}[/red]")
            console.print()
            continue

        # Custom validation
        if validator is not None:
            if not validator(value):
                error_msg = validator_error_msg or "Invalid value"
                console.print(f"[red]✗ {error_msg}[/red]")
                console.print()
                continue

        # All validations passed
        return value


def prompt_float_with_retry(
    message: str,
    default: Optional[float] = None,
    min_val: Optional[float] = None,
    max_val: Optional[float] = None,
) -> float:
    """Prompt for float input with validation and retry.

    Args:
        message: Prompt message
        default: Default value (optional)
        min_val: Minimum allowed value (optional)
        max_val: Maximum allowed value (optional)

    Returns:
        Validated float value
    """
    while True:
        # Build prompt with default
        if default is not None:
            prompt_text = f"{message} [{default}]"
        else:
            prompt_text = message

        # Get input
        user_input = Prompt.ask(prompt_text, default=str(default) if default is not None else None)

        # Try to parse as float
        try:
            value = float(user_input)
        except ValueError:
            console.print("[red]✗ Invalid input. Please enter a valid number.[/red]")
            console.print()
            continue

        # Validate range
        if min_val is not None and value < min_val:
            console.print(f"[red]✗ Value must be at least {min_val}[/red]")
            console.print()
            continue

        if max_val is not None and value > max_val:
            console.print(f"[red]✗ Value must be at most {max_val}[/red]")
            console.print()
            continue

        # All validations passed
        return value


def prompt_choice_with_retry(
    message: str,
    choices: List[str],
    default: Optional[str] = None,
) -> str:
    """Prompt for choice input with validation and retry.

    Args:
        message: Prompt message
        choices: List of valid choices
        default: Default choice (optional)

    Returns:
        Selected choice
    """
    while True:
        # Get input
        user_input = Prompt.ask(message, default=default)

        # Validate choice
        if user_input not in choices:
            console.print(f"[red]✗ Invalid choice. Please select from: {', '.join(choices)}[/red]")
            console.print()
            continue

        return user_input


def prompt_int_list_with_retry(
    message: str,
    default: Optional[str] = None,
    min_val: Optional[int] = None,
    max_val: Optional[int] = None,
    validator: Optional[Callable[[List[int]], tuple[bool, Optional[str]]]] = None,
) -> List[int]:
    """Prompt for comma-separated integer list with validation and retry.

    Args:
        message: Prompt message
        default: Default value as string (e.g., "0,1,2,3")
        min_val: Minimum allowed value for each integer (optional)
        max_val: Maximum allowed value for each integer (optional)
        validator: Custom validation function that returns (is_valid, error_message) (optional)

    Returns:
        List of validated integers
    """
    while True:
        # Get input
        user_input = Prompt.ask(message, default=default)

        # Clean input: support Chinese comma and spaces
        user_input_cleaned = user_input.replace(",", ",").replace(" ", "")

        # Try to parse as integers
        try:
            values = [int(x.strip()) for x in user_input_cleaned.split(",") if x.strip()]
        except ValueError:
            console.print("[red]✗ Invalid format. Please enter numbers separated by commas.[/red]")
            console.print()
            continue

        # Validate each value's range
        invalid_values = []
        for value in values:
            if min_val is not None and value < min_val:
                invalid_values.append(value)
            elif max_val is not None and value > max_val:
                invalid_values.append(value)

        if invalid_values:
            if min_val is not None and max_val is not None:
                console.print(f"[red]✗ Invalid value(s): {invalid_values}[/red]")
                console.print(f"[yellow]Valid range: {min_val}-{max_val}[/yellow]")
            elif min_val is not None:
                console.print(f"[red]✗ Value(s) must be at least {min_val}: {invalid_values}[/red]")
            elif max_val is not None:
                console.print(f"[red]✗ Value(s) must be at most {max_val}: {invalid_values}[/red]")
            console.print()
            continue

        # Custom validation
        if validator is not None:
            is_valid, error_msg = validator(values)
            if not is_valid:
                console.print(f"[red]✗ {error_msg}[/red]")
                console.print()
                continue

        # All validations passed
        return values
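A sketch of the validators in use, e.g. for the GPU selection step mentioned in the commit message; the GPU count is hypothetical:

    from kt_kernel.cli.utils.input_validators import prompt_int_with_retry, prompt_int_list_with_retry

    num_gpus = 8  # hypothetical machine

    tp_size = prompt_int_with_retry(
        "Tensor parallel size", default=1, min_val=1, max_val=num_gpus,
        validator=lambda v: num_gpus % v == 0,
        validator_error_msg="TP size must evenly divide the GPU count",
    )

    # Accepts "0,1,2,3" with either ASCII or Chinese commas.
    gpu_ids = prompt_int_list_with_retry(
        "GPU IDs to use", default="0", min_val=0, max_val=num_gpus - 1,
        validator=lambda ids: (len(set(ids)) == len(ids), "Duplicate GPU IDs"),
    )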
kt-kernel/python/cli/utils/kv_cache_calculator.py (new file, 207 lines)
@@ -0,0 +1,207 @@
#!/usr/bin/env python3
"""
KV Cache Size Calculator for SGLang

This script calculates the KV cache size in GB for a given model and number of tokens.
It follows the same logic as in sglang/srt/model_executor/model_runner.py
"""

import os
import sys
import torch
from transformers import AutoConfig

# Add sglang to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "python"))

from sglang.srt.configs.model_config import ModelConfig, is_deepseek_nsa, get_nsa_index_head_dim
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool


def get_dtype_bytes(dtype_str: str) -> int:
    """Get the number of bytes for a given dtype string."""
    dtype_map = {
        "float32": 4,
        "float16": 2,
        "bfloat16": 2,
        "float8_e4m3fn": 1,
        "float8_e5m2": 1,
        "auto": 2,  # Usually defaults to bfloat16
    }
    return dtype_map.get(dtype_str, 2)


def get_kv_size_gb(
    model_path: str,
    max_total_tokens: int,
    tp: int = 1,
    dtype: str = "auto",
    verbose: bool = True,
) -> dict:
    """
    Calculate the KV cache size in GB for a given model and number of tokens.

    Args:
        model_path: Path to the model
        max_total_tokens: Maximum number of tokens to cache
        tp: Tensor parallelism size
        dtype: Data type for KV cache (auto, float16, bfloat16, float8_e4m3fn, etc.)
        verbose: Whether to print detailed information

    Returns:
        dict: Dictionary containing calculation details
    """
    # Load model config
    model_config = ModelConfig(model_path, dtype=dtype)
    hf_config = model_config.hf_config

    # Determine dtype bytes
    dtype_bytes = get_dtype_bytes(dtype)
    if dtype == "auto":
        # Auto dtype usually becomes bfloat16
        dtype_bytes = 2

    # Number of layers
    num_layers = model_config.num_attention_layers

    # Check if it's an MLA (Multi-head Latent Attention) model
    is_mla = hasattr(model_config, "attention_arch") and model_config.attention_arch.name == "MLA"

    result = {
        "model_path": model_path,
        "max_total_tokens": max_total_tokens,
        "tp": tp,
        "dtype": dtype,
        "dtype_bytes": dtype_bytes,
        "num_layers": num_layers,
        "is_mla": is_mla,
    }

    if is_mla:
        # MLA models (DeepSeek-V2/V3, MiniCPM3, etc.)
        kv_lora_rank = model_config.kv_lora_rank
        qk_rope_head_dim = model_config.qk_rope_head_dim

        # Calculate cell size (per token)
        cell_size = (kv_lora_rank + qk_rope_head_dim) * num_layers * dtype_bytes

        result.update(
            {
                "kv_lora_rank": kv_lora_rank,
                "qk_rope_head_dim": qk_rope_head_dim,
                "cell_size_bytes": cell_size,
            }
        )

        # Check if it's an NSA (Native Sparse Attention) model
        if is_deepseek_nsa(hf_config):
            index_head_dim = get_nsa_index_head_dim(hf_config)
            indexer_size_per_token = index_head_dim + index_head_dim // NSATokenToKVPool.quant_block_size * 4
            indexer_dtype_bytes = torch._utils._element_size(NSATokenToKVPool.index_k_with_scale_buffer_dtype)
            indexer_cell_size = indexer_size_per_token * num_layers * indexer_dtype_bytes
            cell_size += indexer_cell_size

            result.update(
                {
                    "is_nsa": True,
                    "index_head_dim": index_head_dim,
                    "indexer_cell_size_bytes": indexer_cell_size,
                    "total_cell_size_bytes": cell_size,
                }
            )
        else:
            result["is_nsa"] = False
    else:
        # Standard MHA models
        num_kv_heads = model_config.get_num_kv_heads(tp)
        head_dim = model_config.head_dim
        v_head_dim = model_config.v_head_dim

        # Calculate cell size (per token)
        cell_size = num_kv_heads * (head_dim + v_head_dim) * num_layers * dtype_bytes

        result.update(
            {
                "num_kv_heads": num_kv_heads,
                "head_dim": head_dim,
                "v_head_dim": v_head_dim,
                "cell_size_bytes": cell_size,
            }
        )

    # Calculate total KV cache size
    total_size_bytes = max_total_tokens * cell_size
    total_size_gb = total_size_bytes / (1024**3)

    # For MHA models with separate K and V buffers
    if not is_mla:
        k_size_bytes = max_total_tokens * num_kv_heads * head_dim * num_layers * dtype_bytes
        v_size_bytes = max_total_tokens * num_kv_heads * v_head_dim * num_layers * dtype_bytes
        k_size_gb = k_size_bytes / (1024**3)
        v_size_gb = v_size_bytes / (1024**3)

        result.update(
            {
                "k_size_gb": k_size_gb,
                "v_size_gb": v_size_gb,
            }
        )

    result.update(
        {
            "total_size_bytes": total_size_bytes,
            "total_size_gb": total_size_gb,
        }
    )

    if verbose:
        print(f"Model: {model_path}")
        print(f"Tokens: {max_total_tokens}, TP: {tp}, Dtype: {dtype}")
        print(f"Architecture: {'MLA' if is_mla else 'MHA'}")
        print(f"Layers: {num_layers}")

        if is_mla:
            print(f"KV LoRA Rank: {kv_lora_rank}, QK RoPE Head Dim: {qk_rope_head_dim}")
            if result.get("is_nsa"):
                print(f"NSA Index Head Dim: {index_head_dim}")
                print(
                    f"Cell size: {cell_size} bytes (Main: {result['cell_size_bytes']}, Indexer: {result['indexer_cell_size_bytes']})"
                )
            else:
                print(f"Cell size: {cell_size} bytes")
        else:
            print(f"KV Heads: {num_kv_heads}, Head Dim: {head_dim}, V Head Dim: {v_head_dim}")
            print(f"Cell size: {cell_size} bytes")
            print(f"K size: {k_size_gb:.2f} GB, V size: {v_size_gb:.2f} GB")

        print(f"Total KV Cache Size: {total_size_gb:.2f} GB")

    return result


def main():
    import argparse

    parser = argparse.ArgumentParser(description="Calculate KV cache size for a model")
    parser.add_argument("model_path", help="Path to the model")
    parser.add_argument("max_total_tokens", type=int, help="Maximum number of tokens")
    parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism size")
    parser.add_argument("--dtype", type=str, default="auto", help="Data type (auto, float16, bfloat16, etc.)")
    parser.add_argument("--quiet", action="store_true", help="Suppress verbose output")

    args = parser.parse_args()

    result = get_kv_size_gb(
        args.model_path,
        args.max_total_tokens,
        tp=args.tp,
        dtype=args.dtype,
        verbose=not args.quiet,
    )

    if args.quiet:
        print(f"{result['total_size_gb']:.2f}")


if __name__ == "__main__":
    main()
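As a worked example of the MLA branch above, with DeepSeek-V3-style parameters (kv_lora_rank=512, qk_rope_head_dim=64, 61 attention layers, bfloat16 — figures quoted from memory, so treat them as illustrative):

    # Per-token MLA cell size: (kv_lora_rank + qk_rope_head_dim) * num_layers * dtype_bytes
    cell_size = (512 + 64) * 61 * 2               # 70,272 bytes per token
    total_gb = 100_000 * cell_size / (1024**3)
    print(f"{total_gb:.2f} GB for 100k tokens")   # ~6.54 GB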
kt-kernel/python/cli/utils/model_discovery.py (new file, 250 lines)
@@ -0,0 +1,250 @@
"""
Model Discovery Utilities

Shared functions for discovering and registering new models across different commands.
"""

from typing import List, Optional, Tuple
from pathlib import Path
from rich.console import Console

from kt_kernel.cli.utils.model_scanner import (
    discover_models,
    scan_directory_for_models,
    ScannedModel,
)
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry, UserModel


console = Console()


def discover_and_register_global(
    min_size_gb: float = 2.0, max_depth: int = 6, show_progress: bool = True, lang: str = "en"
) -> Tuple[int, int, List[UserModel]]:
    """
    Perform global model discovery and register new models.

    Args:
        min_size_gb: Minimum model size in GB
        max_depth: Maximum search depth
        show_progress: Whether to show progress messages
        lang: Language for messages ("en" or "zh")

    Returns:
        Tuple of (total_found, new_found, registered_models)
    """
    registry = UserModelRegistry()

    if show_progress:
        if lang == "zh":
            console.print("[dim]正在扫描系统中的模型权重,这可能需要30-60秒...[/dim]")
        else:
            console.print("[dim]Scanning system for model weights, this may take 30-60 seconds...[/dim]")

    # Global scan
    all_models = discover_models(mount_points=None, min_size_gb=min_size_gb, max_depth=max_depth)

    # Filter out existing models
    new_models = []
    for model in all_models:
        if not registry.find_by_path(model.path):
            new_models.append(model)

    # Register new models
    registered = []
    for model in new_models:
        user_model = _create_and_register_model(registry, model)
        if user_model:
            registered.append(user_model)

    return len(all_models), len(new_models), registered


def discover_and_register_path(
    path: str,
    min_size_gb: float = 2.0,
    existing_paths: Optional[set] = None,
    show_progress: bool = True,
    lang: str = "en",
) -> Tuple[int, int, List[UserModel]]:
    """
    Discover models in a specific path and register new ones.

    Args:
        path: Directory path to scan
        min_size_gb: Minimum model file size in GB
        existing_paths: Set of already discovered paths in this session (optional)
        show_progress: Whether to show progress messages
        lang: Language for messages ("en" or "zh")

    Returns:
        Tuple of (total_found, new_found, registered_models)
    """
    registry = UserModelRegistry()

    if show_progress:
        if lang == "zh":
            console.print(f"[dim]正在扫描 {path}...[/dim]")
        else:
            console.print(f"[dim]Scanning {path}...[/dim]")

    # Scan directory
    model_info = scan_directory_for_models(path, min_file_size_gb=min_size_gb)

    if not model_info:
        return 0, 0, []

    # Convert to ScannedModel and filter
    new_models = []
    for dir_path, (format_type, size_bytes, file_count, files) in model_info.items():
        # Check if already in registry
        if registry.find_by_path(dir_path):
            continue

        # Check if already discovered in this session
        if existing_paths and dir_path in existing_paths:
            continue

        model = ScannedModel(
            path=dir_path, format=format_type, size_bytes=size_bytes, file_count=file_count, files=files
        )
        new_models.append(model)

    # Register new models
    registered = []
    for model in new_models:
        user_model = _create_and_register_model(registry, model)
        if user_model:
            registered.append(user_model)

    return len(model_info), len(new_models), registered


def _create_and_register_model(registry: UserModelRegistry, scanned_model: ScannedModel) -> Optional[UserModel]:
    """
    Create a UserModel from ScannedModel and register it.

    Handles name conflicts by suggesting a unique name (e.g., model-2, model-3).
    Automatically detects repo_id from README.md YAML frontmatter.
    Automatically detects and caches MoE information for safetensors models.

    Args:
        registry: UserModelRegistry instance
        scanned_model: ScannedModel to register

    Returns:
        Registered UserModel or None if failed
    """
    # Use suggest_name to get a unique name (adds -2, -3, etc. if needed)
    unique_name = registry.suggest_name(scanned_model.folder_name)

    user_model = UserModel(name=unique_name, path=scanned_model.path, format=scanned_model.format)

    # Auto-detect repo_id from README.md (only YAML frontmatter)
    try:
        from kt_kernel.cli.utils.repo_detector import detect_repo_for_model

        repo_info = detect_repo_for_model(scanned_model.path)
        if repo_info:
            repo_id, repo_type = repo_info
            user_model.repo_id = repo_id
            user_model.repo_type = repo_type
    except Exception:
        # Silently continue if detection fails
        pass

    # Auto-detect MoE information for safetensors models
    if scanned_model.format == "safetensors":
        try:
            from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model

            moe_result = analyze_moe_model(scanned_model.path, use_cache=True)
            if moe_result and moe_result.get("is_moe"):
                user_model.is_moe = True
                user_model.moe_num_experts = moe_result.get("num_experts")
                user_model.moe_num_experts_per_tok = moe_result.get("num_experts_per_tok")
            else:
                user_model.is_moe = False
        except Exception:
            # Silently continue if MoE detection fails
            # is_moe will remain None
            pass

    try:
        registry.add_model(user_model)
        return user_model
    except Exception:
        # Should not happen since we used suggest_name, but handle gracefully
        return None


def format_discovery_summary(
    total_found: int,
    new_found: int,
    registered: List[UserModel],
    lang: str = "en",
    show_models: bool = True,
    max_show: int = 10,
) -> None:
    """
    Print formatted discovery summary.

    Args:
        total_found: Total models found
        new_found: New models found
        registered: List of registered UserModel objects
        lang: Language ("en" or "zh")
        show_models: Whether to show model list
        max_show: Maximum models to show
    """
    console.print()

    if new_found == 0:
        if total_found > 0:
            if lang == "zh":
                console.print(f"[green]✓[/green] 扫描完成:找到 {total_found} 个模型,所有模型均已在列表中")
            else:
                console.print(f"[green]✓[/green] Scan complete: found {total_found} models, all already in the list")
        else:
            if lang == "zh":
                console.print("[yellow]未找到模型[/yellow]")
            else:
                console.print("[yellow]No models found[/yellow]")
        return

    # Show summary
    if lang == "zh":
        console.print(f"[green]✓[/green] 扫描完成:找到 {total_found} 个模型,其中 {new_found} 个为新模型")
    else:
        console.print(f"[green]✓[/green] Scan complete: found {total_found} models, {new_found} are new")

    # Show registered count
    if len(registered) > 0:
        if lang == "zh":
            console.print(f"[green]✓[/green] 成功添加 {len(registered)} 个新模型到列表")
        else:
            console.print(f"[green]✓[/green] Successfully added {len(registered)} new models to list")

    # Show model list
    if show_models and registered:
        console.print()
        if lang == "zh":
            console.print(f"[dim]新发现的模型(前{max_show}个):[/dim]")
        else:
            console.print(f"[dim]Newly discovered models (first {max_show}):[/dim]")

        for i, model in enumerate(registered[:max_show], 1):
            # Get size from registry or estimate
            size_str = "?.? GB"
            # Try to find the ScannedModel to get size
            # For now just show name and path
            console.print(f"  {i}. {model.name} ({model.format})")
            console.print(f"     [dim]{model.path}[/dim]")

        if len(registered) > max_show:
            remaining = len(registered) - max_show
            if lang == "zh":
                console.print(f"  [dim]... 还有 {remaining} 个新模型[/dim]")
            else:
                console.print(f"  [dim]... and {remaining} more new models[/dim]")
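A sketch of the path-scan entry point as the first-run setup calls it; the directory is a placeholder:

    from kt_kernel.cli.utils.model_discovery import discover_and_register_path

    total, new, registered = discover_and_register_path(
        path="/mnt/data/models",  # placeholder directory
        min_size_gb=2.0,
        existing_paths=set(),
        show_progress=True,
        lang="en",
    )
    for model in registered:
        print(model.name, model.path)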
kt-kernel/python/cli/utils/model_scanner.py (new file, 790 lines)
@@ -0,0 +1,790 @@
|
||||
"""
|
||||
Model Scanner
|
||||
|
||||
Scans directories for model files (safetensors, gguf) and identifies models
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Set, Tuple, Dict
|
from collections import defaultdict
import os
import subprocess
import json


@dataclass
class ScannedModel:
    """Temporary structure for scanned model information"""

    path: str  # Absolute path to model directory
    format: str  # "safetensors" | "gguf" | "mixed"
    size_bytes: int  # Total size in bytes
    file_count: int  # Number of model files
    files: List[str]  # List of model file names

    @property
    def size_gb(self) -> float:
        """Get size in GB"""
        return self.size_bytes / (1024**3)

    @property
    def folder_name(self) -> str:
        """Get the folder name (default model name)"""
        return Path(self.path).name


class ModelScanner:
    """Scanner for discovering models in directory trees"""

    def __init__(self, min_size_gb: float = 10.0):
        """
        Initialize scanner

        Args:
            min_size_gb: Minimum folder size in GB to be considered a model
        """
        self.min_size_bytes = int(min_size_gb * 1024**3)

    def scan_directory(
        self, base_path: Path, exclude_paths: Optional[Set[str]] = None
    ) -> Tuple[List[ScannedModel], List[str]]:
        """
        Scan directory tree for models

        Args:
            base_path: Root directory to scan
            exclude_paths: Set of absolute paths to exclude from results

        Returns:
            Tuple of (valid_models, warnings)
            - valid_models: List of ScannedModel instances
            - warnings: List of warning messages
        """
        if not base_path.exists():
            raise ValueError(f"Path does not exist: {base_path}")

        if not base_path.is_dir():
            raise ValueError(f"Path is not a directory: {base_path}")

        exclude_paths = exclude_paths or set()
        results: List[ScannedModel] = []
        warnings: List[str] = []

        # Walk the directory tree
        for root, dirs, files in os.walk(base_path):
            root_path = Path(root).resolve()

            # Skip if already registered
            if str(root_path) in exclude_paths:
                dirs[:] = []  # Don't descend into this directory
                continue

            # Check for model files
            safetensors_files = [f for f in files if f.endswith(".safetensors")]
            gguf_files = [f for f in files if f.endswith(".gguf")]

            if not safetensors_files and not gguf_files:
                continue  # No model files in this directory

            # Calculate total size
            model_files = safetensors_files + gguf_files
            total_size = self._calculate_total_size(root_path, model_files)

            # Check if size meets minimum threshold
            if total_size < self.min_size_bytes:
                continue  # Too small, but keep scanning subdirectories

            # Detect format
            if safetensors_files and gguf_files:
                # Mixed format - issue warning
                warnings.append(
                    f"Mixed format detected in {root_path}: "
                    f"{len(safetensors_files)} safetensors + {len(gguf_files)} gguf files. "
                    "Please separate into different folders and re-scan."
                )
                dirs[:] = []  # Don't descend into mixed format directories
                continue

            # Determine format
            format_type = "safetensors" if safetensors_files else "gguf"

            # Create scanned model
            scanned = ScannedModel(
                path=str(root_path),
                format=format_type,
                size_bytes=total_size,
                file_count=len(model_files),
                files=model_files,
            )

            results.append(scanned)

            # Continue scanning subdirectories - they might also contain models.
            # Each subdirectory is independently checked against the size threshold.

        return results, warnings

    def scan_single_path(self, path: Path) -> Optional[ScannedModel]:
        """
        Scan a single path for model files

        Args:
            path: Path to scan

        Returns:
            ScannedModel instance or None if not a valid model
        """
        if not path.exists() or not path.is_dir():
            return None

        # Find model files
        safetensors_files = list(path.glob("*.safetensors"))
        gguf_files = list(path.glob("*.gguf"))

        if not safetensors_files and not gguf_files:
            return None

        # Check for mixed format
        if safetensors_files and gguf_files:
            raise ValueError(
                f"Mixed format detected: {len(safetensors_files)} safetensors + "
                f"{len(gguf_files)} gguf files. Please use a single format."
            )

        # Calculate size
        model_files = [f.name for f in safetensors_files + gguf_files]
        total_size = self._calculate_total_size(path, model_files)

        # Determine format
        format_type = "safetensors" if safetensors_files else "gguf"

        return ScannedModel(
            path=str(path.resolve()),
            format=format_type,
            size_bytes=total_size,
            file_count=len(model_files),
            files=model_files,
        )

    def _calculate_total_size(self, directory: Path, filenames: List[str]) -> int:
        """
        Calculate total size of specified files in directory

        Args:
            directory: Directory containing the files
            filenames: List of filenames to sum

        Returns:
            Total size in bytes
        """
        total = 0
        for filename in filenames:
            file_path = directory / filename
            if file_path.exists():
                try:
                    total += file_path.stat().st_size
                except OSError:
                    # File might be inaccessible, skip it
                    pass
        return total
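
# Illustrative usage sketch (not part of the original file): scan a directory
# tree for models of at least 10 GB. The "/data/models" path is a placeholder.
#
#   scanner = ModelScanner(min_size_gb=10.0)
#   models, warnings = scanner.scan_directory(Path("/data/models"))
#   for m in models:
#       print(f"{m.folder_name}: {m.format}, {m.size_gb:.1f} GB ({m.file_count} files)")
#   for w in warnings:
#       print(f"warning: {w}")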


# Convenience functions


def scan_directory(
    base_path: Path, min_size_gb: float = 10.0, exclude_paths: Optional[Set[str]] = None
) -> Tuple[List[ScannedModel], List[str]]:
    """
    Convenience function to scan a directory

    Args:
        base_path: Root directory to scan
        min_size_gb: Minimum folder size in GB
        exclude_paths: Set of paths to exclude

    Returns:
        Tuple of (models, warnings)
    """
    scanner = ModelScanner(min_size_gb=min_size_gb)
    return scanner.scan_directory(base_path, exclude_paths)


def scan_single_path(path: Path) -> Optional[ScannedModel]:
    """
    Convenience function to scan a single path

    Args:
        path: Path to scan

    Returns:
        ScannedModel or None
    """
    scanner = ModelScanner()
    return scanner.scan_single_path(path)


def format_size(size_bytes: int) -> str:
    """
    Format size in bytes to human-readable string

    Args:
        size_bytes: Size in bytes

    Returns:
        Formatted string (e.g., "42.3 GB")
    """
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if size_bytes < 1024.0:
            return f"{size_bytes:.1f} {unit}"
        size_bytes /= 1024.0
    return f"{size_bytes:.1f} PB"
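
# Example (illustrative): format_size walks up the unit ladder in factors of 1024.
#
#   format_size(1536)            -> "1.5 KB"
#   format_size(45_421_068_288)  -> "42.3 GB"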


# ===== Fast Scanning with Find Command and Tree-based Root Detection =====


def find_files_fast(mount_point: str, pattern: str, max_depth: int = 6, timeout: int = 30) -> List[str]:
    """
    Use the find command to quickly locate files

    Args:
        mount_point: Starting directory
        pattern: File pattern (e.g., "config.json", "*.gguf")
        max_depth: Maximum directory depth (default: 6)
        timeout: Command timeout in seconds

    Returns:
        List of absolute file paths
    """
    try:
        # Use shell=True so stderr can be redirected to /dev/null, ignoring permission errors
        result = subprocess.run(
            f'find "{mount_point}" -maxdepth {max_depth} -name "{pattern}" -type f 2>/dev/null',
            capture_output=True,
            text=True,
            timeout=timeout,
            shell=True,
        )

        # Return results even if the return code is non-zero (e.g., due to
        # permission errors), as long as we got some output.
        if result.stdout:
            return [line.strip() for line in result.stdout.strip().split("\n") if line.strip()]
        return []
    except (subprocess.TimeoutExpired, FileNotFoundError):
        return []
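
# Illustrative usage sketch (not part of the original file): locate GGUF files
# up to four levels below /mnt/disk1 (placeholder path), with a short timeout.
#
#   gguf_paths = find_files_fast("/mnt/disk1", "*.gguf", max_depth=4, timeout=10)
#   for p in gguf_paths:
#       print(p)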


def is_valid_model_directory(directory: Path, min_size_gb: float = 10.0) -> Tuple[bool, Optional[str]]:
    """
    Check if a directory is a valid model directory

    Args:
        directory: Path to check
        min_size_gb: Minimum size in GB

    Returns:
        (is_valid, model_type) where model_type is "safetensors", "gguf", or None
    """
    if not directory.exists() or not directory.is_dir():
        return False, None

    safetensors_files = list(directory.glob("*.safetensors"))
    gguf_files = list(directory.glob("*.gguf"))

    # Determine model type (config.json is optional for safetensors models,
    # so its presence does not change the result here)
    if safetensors_files:
        model_type = "safetensors"
    elif gguf_files:
        model_type = "gguf"
    else:
        return False, None

    # Check size - only stat the model files themselves (fast!)
    model_files = safetensors_files if model_type == "safetensors" else gguf_files
    total_size = 0
    for f in model_files:
        try:
            total_size += f.stat().st_size
        except OSError:
            pass

    size_gb = total_size / (1024**3)
    if size_gb < min_size_gb:
        return False, None

    return True, model_type


def scan_all_models_fast(mount_points: List[str], min_size_gb: float = 10.0, max_depth: int = 6) -> List[str]:
    """
    Fast scan for all model paths using the find command

    Args:
        mount_points: List of mount points to scan
        min_size_gb: Minimum model size in GB
        max_depth: Maximum search depth (default: 6)

    Returns:
        List of valid model directory paths
    """
    model_paths = set()

    for mount in mount_points:
        if not os.path.exists(mount):
            continue

        # Find all config.json files
        config_files = find_files_fast(mount, "config.json", max_depth=max_depth)
        for config_path in config_files:
            model_dir = Path(config_path).parent
            is_valid, model_type = is_valid_model_directory(model_dir, min_size_gb)
            if is_valid:
                model_paths.add(str(model_dir.resolve()))

        # Find all *.gguf files
        gguf_files = find_files_fast(mount, "*.gguf", max_depth=max_depth)
        for gguf_path in gguf_files:
            model_dir = Path(gguf_path).parent
            is_valid, model_type = is_valid_model_directory(model_dir, min_size_gb)
            if is_valid:
                model_paths.add(str(model_dir.resolve()))

    return sorted(model_paths)


def get_root_subdirs() -> List[str]:
    """
    Get subdirectories of / that are worth scanning

    Filters out system paths only

    Returns:
        List of directories to scan
    """
    # System paths to exclude
    excluded = {
        "dev",
        "proc",
        "sys",
        "run",
        "boot",
        "tmp",
        "usr",
        "lib",
        "lib64",
        "bin",
        "sbin",
        "etc",
        "opt",
        "var",
        "snap",
    }

    scan_dirs = []

    try:
        for entry in os.scandir("/"):
            if not entry.is_dir():
                continue

            # Skip excluded paths
            if entry.name in excluded:
                continue

            scan_dirs.append(entry.path)

    except PermissionError:
        pass

    return sorted(scan_dirs)


def scan_directory_for_models(directory: str, min_file_size_gb: float = 2.0) -> Dict[str, tuple]:
    """
    Scan a directory for models using the find command with a size filter

    Uses find -size +2G so that only large model files (>= 2 GB) are located

    Args:
        directory: Directory to scan
        min_file_size_gb: Minimum individual file size in GB (default: 2.0)

    Returns:
        Dict mapping model_path -> (model_type, size_bytes, file_count, files)
    """
    model_info = {}

    # Convert GB to find's size syntax (e.g., 2 GB -> "+2G")
    if min_file_size_gb >= 1.0:
        size_filter = f"+{int(min_file_size_gb)}G"
    else:
        size_mb = int(min_file_size_gb * 1024)
        size_filter = f"+{size_mb}M"

    # 1. Find *.gguf files above the size threshold
    gguf_cmd = f'find "{directory}" -name "*.gguf" -type f -size {size_filter} -printf "%p\\t%s\\n" 2>/dev/null'
    result = subprocess.run(gguf_cmd, shell=True, capture_output=True, text=True, timeout=120)

    # Group by directory
    gguf_dirs = defaultdict(list)
    for line in result.stdout.strip().split("\n"):
        if not line:
            continue
        parts = line.split("\t")
        if len(parts) != 2:
            continue
        file_path, size_str = parts
        file_path_obj = Path(file_path)
        dir_path = str(file_path_obj.parent)
        gguf_dirs[dir_path].append((file_path_obj.name, int(size_str)))

    # Add all gguf directories
    for dir_path, files in gguf_dirs.items():
        total_size = sum(size for _, size in files)
        model_info[dir_path] = ("gguf", total_size, len(files), [name for name, _ in files])

    # 2. Find *.safetensors files above the size threshold
    safetensors_cmd = (
        f'find "{directory}" -name "*.safetensors" -type f -size {size_filter} -printf "%p\\t%s\\n" 2>/dev/null'
    )
    result = subprocess.run(safetensors_cmd, shell=True, capture_output=True, text=True, timeout=120)

    # Group by directory
    safetensors_dirs = defaultdict(list)
    for line in result.stdout.strip().split("\n"):
        if not line:
            continue
        parts = line.split("\t")
        if len(parts) != 2:
            continue
        file_path, size_str = parts
        file_path_obj = Path(file_path)
        dir_path = str(file_path_obj.parent)
        safetensors_dirs[dir_path].append((file_path_obj.name, int(size_str)))

    # 3. Only keep safetensors directories that also contain a config.json
    for dir_path, files in safetensors_dirs.items():
        if os.path.exists(os.path.join(dir_path, "config.json")):
            total_size = sum(size for _, size in files)
            model_info[dir_path] = ("safetensors", total_size, len(files), [name for name, _ in files])

    return model_info
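
# Illustrative sketch (not part of the original file): the find invocation above
# emits "path<TAB>size" lines, which are then grouped per directory. For example,
# scanning a placeholder "/data" directory:
#
#   info = scan_directory_for_models("/data", min_file_size_gb=2.0)
#   for path, (fmt, size_bytes, count, names) in info.items():
#       print(path, fmt, format_size(size_bytes), count)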


def scan_all_models_with_info(
    mount_points: Optional[List[str]] = None, min_size_gb: float = 10.0, max_depth: int = 6
) -> Dict[str, tuple]:
    """
    Fast scan with parallel directory scanning

    Strategy:
    1. Use the provided directories, or auto-detect root subdirectories
    2. Scan each directory in parallel (one thread per directory)
    3. Use find -size +2G to find large model files (>= 2 GB)

    Args:
        mount_points: Specific directories to scan, or None to auto-detect from / subdirs
        min_size_gb: No longer used (kept for API compatibility)
        max_depth: No longer used (kept for API compatibility)

    Returns:
        Dict mapping model_path -> (model_type, size_bytes, file_count, files)
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    # Get directories to scan
    if mount_points is None:
        # Get root subdirectories (excluding system paths)
        scan_dirs = get_root_subdirs()
    else:
        scan_dirs = mount_points

    if not scan_dirs:
        return {}

    model_info = {}

    # Scan each directory in parallel (max 8 concurrent),
    # using the 2 GB per-file threshold to find model files
    with ThreadPoolExecutor(max_workers=min(len(scan_dirs), 8)) as executor:
        futures = {executor.submit(scan_directory_for_models, d, 2.0): d for d in scan_dirs}

        for future in as_completed(futures):
            try:
                dir_results = future.result()
                model_info.update(dir_results)
            except Exception:
                # Skip directories that fail to scan
                pass

    return model_info


def find_model_roots_from_paths(model_paths: List[str]) -> Tuple[List[str], Dict[str, int]]:
    """
    Find optimal root paths from model paths using a tree-based algorithm

    Algorithm:
    1. Build a path tree containing all intermediate paths
    2. DFS to calculate f(x) = subtree sum (number of models in the subtree)
    3. Find roots where f(parent) = f(x) >= max(f(children))

    Args:
        model_paths: List of model directory paths

    Returns:
        (root_paths, subtree_sizes) where:
        - root_paths: List of inferred root directories
        - subtree_sizes: Dict mapping each root to its number of models
    """
    if not model_paths:
        return [], {}

    # 1. Build the path set (including all intermediate paths)
    all_paths = set()
    model_set = set(model_paths)

    for model_path in model_paths:
        path = Path(model_path)
        for i in range(1, len(path.parts) + 1):
            all_paths.add(str(Path(*path.parts[:i])))

    # 2. Build parent-child relationships
    children_map = defaultdict(list)
    for path in all_paths:
        path_obj = Path(path)
        if len(path_obj.parts) > 1:
            parent = str(path_obj.parent)
            if parent in all_paths:
                children_map[parent].append(path)

    # 3. DFS to calculate f(x) and max_child_f(x)
    f = {}  # path -> subtree sum
    max_child_f = {}  # path -> max(f(children))
    visited = set()

    def dfs(path: str) -> int:
        if path in visited:
            return f[path]
        visited.add(path)

        # Current node weight (1 if it's a model path, 0 otherwise)
        weight = 1 if path in model_set else 0

        # Recursively calculate children
        children = children_map.get(path, [])
        if not children:
            # Leaf node
            f[path] = weight
            max_child_f[path] = 0
            return weight

        # Calculate f values for all children
        children_f_values = [dfs(child) for child in children]

        # Calculate f(x) and max_child_f(x)
        f[path] = weight + sum(children_f_values)
        max_child_f[path] = max(children_f_values) if children_f_values else 0

        return f[path]

    # Find top-level nodes (no parent in all_paths)
    top_nodes = []
    for path in all_paths:
        parent = str(Path(path).parent)
        if parent not in all_paths or parent == path:
            top_nodes.append(path)

    # Execute DFS from all top nodes
    for top in top_nodes:
        dfs(top)

    # 4. Find root nodes: f(parent) = f(x) >= max(f(children))
    # Note: >= rather than > handles the case where a directory contains only one model
    candidate_roots = []
    for path in all_paths:
        # Skip model paths themselves (leaf nodes in the model tree)
        if path in model_set:
            continue

        parent = str(Path(path).parent)

        # Check condition: f(parent) = f(x) and f(x) >= max(f(children))
        if parent in f and f.get(parent, 0) == f.get(path, 0):
            if f.get(path, 0) >= max_child_f.get(path, 0) and f.get(path, 0) > 0:
                candidate_roots.append(path)

    # 5. Remove redundant roots (prefer deeper paths):
    # if a root is an ancestor of another root with the same f value, drop the ancestor
    roots = []
    candidate_roots_sorted = sorted(candidate_roots, key=lambda p: -len(Path(p).parts))

    for root in candidate_roots_sorted:
        # Check if this root is a parent of any already selected root
        is_redundant = False
        for selected in roots:
            if selected.startswith(root + "/"):
                # selected is a descendant of root; keeping root would be redundant
                # when both cover the same number of models
                if f.get(root, 0) == f.get(selected, 0):
                    is_redundant = True
                    break

        if not is_redundant:
            # Also filter out very shallow paths (< 3 levels)
            if len(Path(root).parts) >= 3:
                roots.append(root)

    # Build subtree sizes for the selected roots
    subtree_sizes = {root: f.get(root, 0) for root in roots}

    return sorted(roots), subtree_sizes
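
# Worked example (illustrative, placeholder paths). Given:
#
#   paths = [
#       "/mnt/disk1/models/llama-a",
#       "/mnt/disk1/models/llama-b",
#       "/mnt/disk1/misc/qwen",
#   ]
#   roots, sizes = find_model_roots_from_paths(paths)
#
# Here f(/mnt/disk1) = 3 equals f(/mnt) and f(/), and the redundancy and depth
# filters drop the shallower ancestors, so this returns
# (["/mnt/disk1"], {"/mnt/disk1": 3}).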


@dataclass
class ModelRootInfo:
    """Information about a detected model root path"""

    path: str
    model_count: int
    models: List[ScannedModel]


def discover_models(
    mount_points: Optional[List[str]] = None, min_size_gb: float = 10.0, max_depth: int = 6
) -> List[ScannedModel]:
    """
    Discover all model directories on the system

    Fast scan using the find command to locate all models that meet the criteria

    Args:
        mount_points: List of mount points to scan (None = auto-detect)
        min_size_gb: Minimum model size in GB (default: 10.0)
        max_depth: Maximum search depth (default: 6)

    Returns:
        List of ScannedModel sorted by path
    """
    # Auto-detect mount points if not provided
    if mount_points is None:
        mount_points = _get_mount_points()

    # Fast scan with cached info (only scan once!)
    model_info = scan_all_models_with_info(mount_points, min_size_gb, max_depth)

    if not model_info:
        return []

    # Convert to ScannedModel objects
    results = []
    for model_path, (model_type, total_size, file_count, files) in model_info.items():
        results.append(
            ScannedModel(path=model_path, format=model_type, size_bytes=total_size, file_count=file_count, files=files)
        )

    # Sort by path
    results.sort(key=lambda m: m.path)
    return results
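
# Illustrative usage sketch (not part of the original file): auto-detect mount
# points and print everything that was discovered.
#
#   for m in discover_models():
#       print(f"{m.path} [{m.format}] {format_size(m.size_bytes)}")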


def _get_mount_points() -> List[str]:
    """
    Get all valid mount points from /proc/mounts, filtering out system paths

    Returns:
        List of mount point paths suitable for model storage
        (excludes root "/" to avoid scanning the entire filesystem)
    """
    mount_points = set()

    # System paths to exclude (unlikely to contain model files)
    excluded_paths = [
        "/snap/",
        "/proc/",
        "/sys/",
        "/run/",
        "/boot",
        "/dev/",
        "/usr",
        "/lib",
        "/lib64",
        "/bin",
        "/sbin",
        "/etc",
        "/opt",
        "/var",
        "/tmp",
    ]

    # Pseudo filesystems to filter out
    pseudo_fs = {
        "proc",
        "sysfs",
        "devpts",
        "tmpfs",
        "devtmpfs",
        "cgroup",
        "cgroup2",
        "pstore",
        "bpf",
        "tracefs",
        "debugfs",
        "hugetlbfs",
        "mqueue",
        "configfs",
        "securityfs",
        "fuse.gvfsd-fuse",
        "fusectl",
        "squashfs",
        "overlay",  # snap packages
    }

    try:
        with open("/proc/mounts", "r") as f:
            for line in f:
                parts = line.split()
                if len(parts) < 3:
                    continue

                device, mount_point, fs_type = parts[0], parts[1], parts[2]

                # Filter out pseudo filesystems
                if fs_type in pseudo_fs:
                    continue

                # Skip the root directory (too large to scan)
                if mount_point == "/":
                    continue

                # Filter out system paths
                if any(mount_point.startswith(x) for x in excluded_paths):
                    continue

                # Only include mount points that exist and are readable
                if os.path.exists(mount_point) and os.access(mount_point, os.R_OK):
                    mount_points.add(mount_point)

        # If no mount points were found, fall back to common data directories
        if not mount_points:
            common_paths = ["/home", "/data", "/mnt"]
            for path in common_paths:
                if os.path.exists(path) and os.access(path, os.R_OK):
                    mount_points.add(path)

    except (FileNotFoundError, PermissionError):
        # Fallback to common paths
        mount_points = {"/home", "/mnt", "/data"}

    return sorted(mount_points)
254  kt-kernel/python/cli/utils/model_table_builder.py  Normal file
@@ -0,0 +1,254 @@
"""
Shared model table builders for consistent UI across commands.

Provides reusable table construction functions for displaying models
in kt model list, kt quant, kt run, etc.
"""

from typing import List, Optional, Tuple
from pathlib import Path
from rich.table import Table
from rich.console import Console
import json


def format_model_size(model_path: Path, format_type: str) -> str:
    """Calculate and format model size."""
    from kt_kernel.cli.utils.model_scanner import format_size

    try:
        if format_type == "safetensors":
            files = list(model_path.glob("*.safetensors"))
        elif format_type == "gguf":
            files = list(model_path.glob("*.gguf"))
        else:
            return "[dim]-[/dim]"

        total_size = sum(f.stat().st_size for f in files if f.exists())
        return format_size(total_size)
    except Exception:
        return "[dim]-[/dim]"


def format_repo_info(model) -> str:
    """Format repository information."""
    if model.repo_id:
        repo_abbr = "hf" if model.repo_type == "huggingface" else "ms"
        return f"{repo_abbr}:{model.repo_id}"
    return "[dim]-[/dim]"


def format_sha256_status(model, status_map: dict) -> str:
    """Format SHA256 verification status."""
    return status_map.get(model.sha256_status or "not_checked", "[dim]?[/dim]")


def build_moe_gpu_table(
    models: List, status_map: dict, show_index: bool = True, start_index: int = 1
) -> Tuple[Table, List]:
    """
    Build the MoE GPU models table.

    Args:
        models: List of MoE GPU model objects
        status_map: SHA256_STATUS_MAP for formatting status
        show_index: Whether to show the # column for selection (default: True)
        start_index: Starting index number

    Returns:
        Tuple of (Table object, list of models in display order)
    """
    table = Table(show_header=True, header_style="bold", show_lines=False)

    if show_index:
        table.add_column("#", justify="right", style="cyan", no_wrap=True)

    table.add_column("Name", style="cyan", no_wrap=True)
    table.add_column("Path", style="dim", overflow="fold")
    table.add_column("Total", justify="right")
    table.add_column("Exps", justify="center", style="yellow")
    table.add_column("Act", justify="center", style="green")
    table.add_column("Repository", style="dim", overflow="fold")
    table.add_column("SHA256", justify="center")

    displayed_models = []

    for i, model in enumerate(models, start_index):
        displayed_models.append(model)

        # Calculate size
        size_str = format_model_size(Path(model.path), "safetensors")

        # MoE info
        num_experts = str(model.moe_num_experts) if model.moe_num_experts else "[dim]-[/dim]"
        num_active = str(model.moe_num_experts_per_tok) if model.moe_num_experts_per_tok else "[dim]-[/dim]"

        # Repository and SHA256
        repo_str = format_repo_info(model)
        sha256_str = format_sha256_status(model, status_map)

        row = []
        if show_index:
            row.append(str(i))

        row.extend([model.name, model.path, size_str, num_experts, num_active, repo_str, sha256_str])

        table.add_row(*row)

    return table, displayed_models
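
# Illustrative usage sketch (not part of the original file). Here `models` is a
# placeholder list of registered MoE model records, and the status map is a
# simplified stand-in for the CLI's real SHA256_STATUS_MAP.
#
#   from rich.console import Console
#
#   status_map = {"passed": "[green]✓[/green]", "failed": "[red]✗[/red]"}
#   table, shown = build_moe_gpu_table(models, status_map, show_index=True)
#   Console().print(table)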


def build_amx_table(
    models: List,
    status_map: Optional[dict] = None,  # Kept for API compatibility but not used
    show_index: bool = True,
    start_index: int = 1,
    show_linked_gpus: bool = False,
    gpu_models: Optional[List] = None,
) -> Tuple[Table, List]:
    """
    Build the AMX models table.

    Note: AMX models are locally quantized, so there is no SHA256 verification column.

    Args:
        models: List of AMX model objects
        status_map: (Unused - kept for API compatibility)
        show_index: Whether to show the # column for selection (default: True)
        start_index: Starting index number
        show_linked_gpus: Whether to show sub-rows for linked GPU models
        gpu_models: List of GPU models (required if show_linked_gpus=True)

    Returns:
        Tuple of (Table object, list of models in display order)
    """
    table = Table(show_header=True, header_style="bold", show_lines=False)

    if show_index:
        table.add_column("#", justify="right", style="cyan", no_wrap=True)

    table.add_column("Name", style="cyan", no_wrap=True)
    table.add_column("Path", style="dim", overflow="fold")
    table.add_column("Total", justify="right")
    table.add_column("Method", justify="center", style="yellow")
    table.add_column("NUMA", justify="center", style="green")
    table.add_column("Source", style="dim", overflow="fold")

    # Build a reverse map (AMX model id -> GPU model names) if needed
    amx_used_by_gpu = {}
    if show_linked_gpus and gpu_models:
        for model in models:
            if model.gpu_model_ids:
                gpu_names = []
                for gpu_id in model.gpu_model_ids:
                    for gpu_model in gpu_models:
                        if gpu_model.id == gpu_id:
                            gpu_names.append(gpu_model.name)
                            break
                if gpu_names:
                    amx_used_by_gpu[model.id] = gpu_names

    displayed_models = []

    for i, model in enumerate(models, start_index):
        displayed_models.append(model)

        # Calculate size
        size_str = format_model_size(Path(model.path), "safetensors")

        # Read metadata from config.json or UserModel fields
        method_from_config = None
        numa_from_config = None
        try:
            config_path = Path(model.path) / "config.json"
            if config_path.exists():
                with open(config_path, "r", encoding="utf-8") as f:
                    config = json.load(f)
                amx_quant = config.get("amx_quantization", {})
                if amx_quant.get("converted"):
                    method_from_config = amx_quant.get("method")
                    numa_from_config = amx_quant.get("numa_count")
        except Exception:
            pass

        # Priority: UserModel fields > config.json > "?" placeholder
        method_display = (
            model.amx_quant_method.upper()
            if model.amx_quant_method
            else method_from_config.upper() if method_from_config else "[dim]?[/dim]"
        )
        numa_display = (
            str(model.amx_numa_nodes)
            if model.amx_numa_nodes
            else str(numa_from_config) if numa_from_config else "[dim]?[/dim]"
        )
        source_display = model.amx_source_model or "[dim]-[/dim]"

        row = []
        if show_index:
            row.append(str(i))

        row.extend([model.name, model.path, size_str, method_display, numa_display, source_display])

        table.add_row(*row)

        # Add a sub-row showing the linked GPU models
        if show_linked_gpus and model.id in amx_used_by_gpu:
            gpu_list = amx_used_by_gpu[model.id]
            gpu_names_str = ", ".join([f"[dim]{name}[/dim]" for name in gpu_list])
            sub_row = []
            if show_index:
                sub_row.append("")
            sub_row.extend([f"  [dim]↳ GPU: {gpu_names_str}[/dim]", "", "", "", "", ""])
            table.add_row(*sub_row, style="dim")

    return table, displayed_models


def build_gguf_table(
    models: List, status_map: dict, show_index: bool = True, start_index: int = 1
) -> Tuple[Table, List]:
    """
    Build the GGUF models table.

    Args:
        models: List of GGUF model objects
        status_map: SHA256_STATUS_MAP for formatting status
        show_index: Whether to show the # column for selection (default: True)
        start_index: Starting index number

    Returns:
        Tuple of (Table object, list of models in display order)
    """
    table = Table(show_header=True, header_style="bold", show_lines=False)

    if show_index:
        table.add_column("#", justify="right", style="cyan", no_wrap=True)

    table.add_column("Name", style="cyan", no_wrap=True)
    table.add_column("Path", style="dim", overflow="fold")
    table.add_column("Total", justify="right")
    table.add_column("Repository", style="dim", overflow="fold")
    table.add_column("SHA256", justify="center")

    displayed_models = []

    for i, model in enumerate(models, start_index):
        displayed_models.append(model)

        # Calculate size
        size_str = format_model_size(Path(model.path), "gguf")

        # Repository and SHA256
        repo_str = format_repo_info(model)
        sha256_str = format_sha256_status(model, status_map)

        row = []
        if show_index:
            row.append(str(i))

        row.extend([model.name, model.path, size_str, repo_str, sha256_str])

        table.add_row(*row)

    return table, displayed_models
918  kt-kernel/python/cli/utils/model_verifier.py  Normal file
@@ -0,0 +1,918 @@
"""
Model Verifier

SHA256 verification for model integrity
"""

import hashlib
import requests
import os
from pathlib import Path
from typing import Dict, Any, Literal, Tuple
from concurrent.futures import ProcessPoolExecutor, as_completed


def _compute_file_sha256(file_path: Path) -> Tuple[str, str, float]:
    """
    Compute SHA256 for a single file (worker function for multiprocessing).

    Args:
        file_path: Path to the file

    Returns:
        Tuple of (filename, sha256_hash, file_size_mb)
    """
    sha256_hash = hashlib.sha256()
    file_size_mb = file_path.stat().st_size / (1024 * 1024)

    # Read the file in chunks to handle large files
    with open(file_path, "rb") as f:
        for byte_block in iter(lambda: f.read(8192 * 1024), b""):  # 8MB chunks
            sha256_hash.update(byte_block)

    return file_path.name, sha256_hash.hexdigest(), file_size_mb


def check_huggingface_connectivity(timeout: int = 5) -> Tuple[bool, str]:
    """
    Check whether HuggingFace is accessible.

    Args:
        timeout: Connection timeout in seconds

    Returns:
        Tuple of (is_accessible, message)
    """
    test_url = "https://huggingface.co"

    try:
        response = requests.head(test_url, timeout=timeout, allow_redirects=True)
        if response.status_code < 500:  # 2xx, 3xx, 4xx are all considered "accessible"
            return True, "HuggingFace is accessible"
        return False, f"Server error: HTTP {response.status_code}"
    except requests.exceptions.Timeout:
        return False, f"Connection to {test_url} timed out"
    except requests.exceptions.ConnectionError:
        return False, f"Cannot connect to {test_url}"
    except requests.exceptions.RequestException as e:
        return False, f"Connection error: {str(e)}"


def verify_model_integrity(
    repo_type: Literal["huggingface", "modelscope"],
    repo_id: str,
    local_dir: Path,
    progress_callback=None,
) -> Dict[str, Any]:
    """
    Verify local model integrity against remote repository SHA256 hashes.

    Verifies all important files:
    - *.safetensors (weights)
    - *.json (config files)
    - *.py (custom model code)

    Args:
        repo_type: Type of repository ("huggingface" or "modelscope")
        repo_id: Repository ID (e.g., "deepseek-ai/DeepSeek-V3")
        local_dir: Local directory containing model files
        progress_callback: Optional callback function(message: str) for progress updates

    Returns:
        Dictionary with verification results:
        {
            "status": "passed" | "failed" | "error",
            "files_checked": int,
            "files_passed": int,
            "files_failed": [list of filenames],
            "error_message": str (optional)
        }
    """

    def report_progress(msg: str):
        """Helper to report progress"""
        if progress_callback:
            progress_callback(msg)

    try:
        # Convert repo_type to platform format
        platform = "hf" if repo_type == "huggingface" else "ms"

        # 1. Fetch official SHA256 hashes from the remote repository
        report_progress("Fetching official SHA256 hashes from remote repository...")
        official_hashes = fetch_model_sha256(repo_id, platform)
        report_progress(f"✓ Fetched {len(official_hashes)} file hashes from remote")

        if not official_hashes:
            return {
                "status": "error",
                "files_checked": 0,
                "files_passed": 0,
                "files_failed": [],
                "error_message": f"No verifiable files found in remote repository: {repo_id}",
            }

        # 2. Calculate local SHA256 hashes with progress
        report_progress("Calculating SHA256 for local files...")

        # Get all local files matching the patterns
        local_files = []
        for pattern in ["*.safetensors", "*.json", "*.py"]:
            local_files.extend([f for f in local_dir.glob(pattern) if f.is_file()])

        if not local_files:
            return {
                "status": "error",
                "files_checked": 0,
                "files_passed": 0,
                "files_failed": [],
                "error_message": f"No verifiable files found in local directory: {local_dir}",
            }

        # Calculate hashes for all files
        local_hashes = calculate_local_sha256(
            local_dir,
            file_pattern="*.safetensors",  # Unused when files_list is provided
            progress_callback=report_progress,
            files_list=local_files,
        )
        report_progress(f"✓ Calculated {len(local_hashes)} local file hashes")

        # 3. Compare hashes with progress
        report_progress(f"Comparing {len(official_hashes)} files...")
        files_failed = []
        files_missing = []
        files_passed = 0

        for idx, (filename, official_hash) in enumerate(official_hashes.items(), 1):
            # Handle potential path separators in the filename
            file_basename = Path(filename).name

            # Try to find the file in the local hashes
            local_hash = None
            for local_file, local_hash_value in local_hashes.items():
                if Path(local_file).name == file_basename:
                    local_hash = local_hash_value
                    break

            if local_hash is None:
                files_missing.append(filename)
                report_progress(f"  [{idx}/{len(official_hashes)}] ✗ {file_basename} - MISSING")
            elif local_hash.lower() != official_hash.lower():
                files_failed.append(f"{filename} (hash mismatch)")
                report_progress(f"  [{idx}/{len(official_hashes)}] ✗ {file_basename} - HASH MISMATCH")
            else:
                files_passed += 1
                report_progress(f"  [{idx}/{len(official_hashes)}] ✓ {file_basename}")

        # 4. Return results
        total_checked = len(official_hashes)

        if files_failed or files_missing:
            all_failed = files_failed + [f"{f} (missing)" for f in files_missing]
            return {
                "status": "failed",
                "files_checked": total_checked,
                "files_passed": files_passed,
                "files_failed": all_failed,
                "error_message": f"{len(all_failed)} file(s) failed verification",
            }
        else:
            return {
                "status": "passed",
                "files_checked": total_checked,
                "files_passed": files_passed,
                "files_failed": [],
            }

    except ImportError as e:
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": f"Missing required package: {str(e)}. Install with: pip install huggingface-hub modelscope",
            "is_network_error": False,
        }
    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
        requests.exceptions.RequestException,
    ) as e:
        # Network-related errors - suggest using a mirror
        error_msg = f"Network error: {str(e)}"
        if repo_type == "huggingface":
            error_msg += "\n\nTry using HuggingFace mirror:\n  export HF_ENDPOINT=https://hf-mirror.com"
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": error_msg,
            "is_network_error": True,
        }
    except Exception as e:
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": f"Verification failed: {str(e)}",
            "is_network_error": False,
        }
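
# Illustrative usage sketch (not part of the original file); the local path is a
# placeholder, and the repo ID is the example used in the docstring above.
#
#   result = verify_model_integrity(
#       "huggingface", "deepseek-ai/DeepSeek-V3", Path("/data/models/DeepSeek-V3")
#   )
#   if result["status"] != "passed":
#       print(result.get("error_message"), result["files_failed"])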


def calculate_local_sha256(
    local_dir: Path, file_pattern: str = "*.safetensors", progress_callback=None, files_list: list[Path] | None = None
) -> Dict[str, str]:
    """
    Calculate SHA256 hashes for files in a directory using parallel processing.

    Args:
        local_dir: Directory to scan
        file_pattern: Glob pattern for files to hash (ignored if files_list is provided)
        progress_callback: Optional callback function(message: str) for progress updates
        files_list: Optional pre-filtered list of files to hash (overrides file_pattern)

    Returns:
        Dictionary mapping filename to SHA256 hash
    """
    result = {}

    if not local_dir.exists():
        return result

    # Collect all files first so the total can be reported
    if files_list is not None:
        files_to_hash = files_list
    else:
        files_to_hash = [f for f in local_dir.glob(file_pattern) if f.is_file()]
    total_files = len(files_to_hash)

    if total_files == 0:
        return result

    # Use min(16, total_files) workers to avoid over-spawning processes
    max_workers = min(16, total_files)

    if progress_callback:
        progress_callback(f"  Using {max_workers} parallel workers for SHA256 calculation")

    # Use ProcessPoolExecutor for CPU-intensive SHA256 computation
    completed_count = 0
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_file = {executor.submit(_compute_file_sha256, file_path): file_path for file_path in files_to_hash}

        # Process results as they complete
        for future in as_completed(future_to_file):
            completed_count += 1
            try:
                filename, sha256_hash, file_size_mb = future.result()
                result[filename] = sha256_hash

                if progress_callback:
                    progress_callback(f"  [{completed_count}/{total_files}] ✓ {filename} ({file_size_mb:.1f} MB)")

            except Exception as e:
                file_path = future_to_file[future]
                if progress_callback:
                    progress_callback(f"  [{completed_count}/{total_files}] ✗ {file_path.name} - Error: {str(e)}")

    return result
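
# Illustrative usage sketch (not part of the original file); the directory is a
# placeholder.
#
#   hashes = calculate_local_sha256(Path("/data/models/DeepSeek-V3"))
#   for name, digest in sorted(hashes.items()):
#       print(f"{digest}  {name}")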


def fetch_model_sha256(
    repo_id: str,
    platform: Literal["hf", "ms"],
    revision: str | None = None,
    use_mirror: bool = False,
    timeout: int | None = None,
) -> dict[str, str]:
    """
    Fetch the SHA256 hashes of all important files in a model repository.

    Covers:
    - *.safetensors (weight files)
    - *.json (config files: config.json, tokenizer_config.json, etc.)
    - *.py (custom model code: modeling.py, configuration.py, etc.)

    Args:
        repo_id: Repository ID, e.g. "Qwen/Qwen3-30B-A3B"
        platform: Platform, "hf" (HuggingFace) or "ms" (ModelScope)
        revision: Revision/branch; defaults to "main" for HuggingFace and "master" for ModelScope
        use_mirror: Whether to use the mirror (HuggingFace only)
        timeout: Network request timeout in seconds; None means no timeout

    Returns:
        dict: Mapping from filename to sha256, e.g.
        {"model-00001-of-00016.safetensors": "abc123...", "config.json": "def456..."}
    """
    if platform == "hf":
        # Try a direct connection first; on failure, fall back to the mirror automatically
        try:
            return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=use_mirror, timeout=timeout)
        except Exception:
            # If the direct connection failed, retry via the mirror
            if not use_mirror:
                return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=True, timeout=timeout)
            raise
    elif platform == "ms":
        return _fetch_from_modelscope(repo_id, revision or "master", timeout=timeout)
    else:
        raise ValueError(f"Unsupported platform: {platform}, please use 'hf' or 'ms'")
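
# Illustrative usage sketch (not part of the original file): fetch remote hashes
# for the repository named in the docstring above (requires network access).
#
#   hashes = fetch_model_sha256("Qwen/Qwen3-30B-A3B", "hf", timeout=30)
#   print(len(hashes), "files;", hashes.get("config.json"))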


def _fetch_from_huggingface(
    repo_id: str, revision: str, use_mirror: bool = False, timeout: int | None = None
) -> dict[str, str]:
    """Fetch the sha256 of all important files from HuggingFace.

    Args:
        repo_id: Repository ID
        revision: Revision/branch
        use_mirror: Whether to use the mirror (hf-mirror.com)
        timeout: Network request timeout in seconds; None means no timeout
    """
    import os
    import socket

    # If the mirror is requested, point HF_ENDPOINT at it
    original_endpoint = os.environ.get("HF_ENDPOINT")
    if use_mirror and not original_endpoint:
        os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

    # Set the socket timeout if specified
    original_timeout = socket.getdefaulttimeout()
    if timeout is not None:
        socket.setdefaulttimeout(timeout)

    from huggingface_hub import HfApi, list_repo_files

    try:
        api = HfApi()
        all_files = list_repo_files(repo_id=repo_id, revision=revision)

        # Select the important files: *.safetensors, *.json, *.py
        important_files = [f for f in all_files if f.endswith((".safetensors", ".json", ".py"))]

        if not important_files:
            return {}

        paths_info = api.get_paths_info(
            repo_id=repo_id,
            paths=important_files,
            revision=revision,
        )

        result = {}
        for file_info in paths_info:
            if hasattr(file_info, "lfs") and file_info.lfs is not None:
                sha256 = file_info.lfs.sha256
            else:
                sha256 = getattr(file_info, "blob_id", None)
            result[file_info.path] = sha256

        return result
    finally:
        # Restore the original socket timeout
        socket.setdefaulttimeout(original_timeout)

        # Restore the original environment variable
        if use_mirror and not original_endpoint:
            os.environ.pop("HF_ENDPOINT", None)
        elif original_endpoint:
            os.environ["HF_ENDPOINT"] = original_endpoint


def _fetch_from_modelscope(repo_id: str, revision: str, timeout: int | None = None) -> dict[str, str]:
    """Fetch the sha256 of all important files from ModelScope.

    Args:
        repo_id: Repository ID
        revision: Revision/branch
        timeout: Network request timeout in seconds; None means no timeout
    """
    import socket
    from modelscope.hub.api import HubApi

    # Set the socket timeout if specified
    original_timeout = socket.getdefaulttimeout()
    if timeout is not None:
        socket.setdefaulttimeout(timeout)

    try:
        api = HubApi()
        files_info = api.get_model_files(model_id=repo_id, revision=revision)

        result = {}
        for file_info in files_info:
            filename = file_info.get("Name", file_info.get("Path", ""))
            # Select the important files: *.safetensors, *.json, *.py
            if filename.endswith((".safetensors", ".json", ".py")):
                sha256 = file_info.get("Sha256", file_info.get("sha256", None))
                result[filename] = sha256

        return result
    finally:
        # Restore the original socket timeout
        socket.setdefaulttimeout(original_timeout)


def verify_model_integrity_with_progress(
    repo_type: Literal["huggingface", "modelscope"],
    repo_id: str,
    local_dir: Path,
    progress_callback=None,
    verbose: bool = False,
    use_mirror: bool = False,
    files_to_verify: list[str] | None = None,
    timeout: int | None = None,
) -> Dict[str, Any]:
    """
    Verify model integrity with enhanced progress reporting for Rich Progress bars.

    This is a wrapper around verify_model_integrity() that provides more detailed
    progress information suitable for progress bar display.

    The progress_callback receives:
    - (message: str, total: int, current: int) for countable operations
    - (message: str) for status updates

    Args:
        repo_type: Repository type ("huggingface" or "modelscope")
        repo_id: Repository ID
        local_dir: Local directory path
        progress_callback: Optional callback for progress updates
        verbose: If True, output a detailed SHA256 comparison for each file
        use_mirror: If True, use the HuggingFace mirror (hf-mirror.com)
        files_to_verify: Optional list of specific files to verify (for re-verification)
        timeout: Network request timeout in seconds (None = no timeout)
    """

    def report_progress(msg: str, total=None, current=None):
        """Enhanced progress reporter"""
        if progress_callback:
            progress_callback(msg, total, current)

    try:
        platform = "hf" if repo_type == "huggingface" else "ms"

        # 1. Fetch official SHA256 hashes
        if files_to_verify:
            report_progress(f"Fetching SHA256 hashes for {len(files_to_verify)} files...")
        elif use_mirror and platform == "hf":
            report_progress("Fetching official SHA256 hashes from mirror (hf-mirror.com)...")
        else:
            report_progress("Fetching official SHA256 hashes from remote repository...")

        official_hashes = fetch_model_sha256(repo_id, platform, use_mirror=use_mirror, timeout=timeout)

        # Filter to only the requested files if specified
        if files_to_verify:
            # Extract clean filenames from files_to_verify (remove markers like "(missing)")
            clean_filenames = set()
            for f in files_to_verify:
                clean_f = f.replace(" (missing)", "").replace(" (hash mismatch)", "").strip()
                # Only use the filename, not the full path
                clean_filenames.add(Path(clean_f).name)

            # Filter official_hashes down to the requested files,
            # comparing by basename since official_hashes keys may contain paths
            official_hashes = {k: v for k, v in official_hashes.items() if Path(k).name in clean_filenames}

        report_progress(f"✓ Fetched {len(official_hashes)} file hashes from remote")

        if not official_hashes:
            return {
                "status": "error",
                "files_checked": 0,
                "files_passed": 0,
                "files_failed": [],
                "error_message": f"No verifiable files found in remote repository: {repo_id}",
            }

        # 2. Calculate local SHA256 hashes
        local_dir_path = Path(local_dir)

        # Hash the same file types that are fetched remotely: weights, configs, code
        patterns = ["*.safetensors", "*.json", "*.py"]
        if files_to_verify:
            # Extract clean filenames (without markers) and only hash matching files
            clean_filenames = set()
            for f in files_to_verify:
                clean_f = f.replace(" (missing)", "").replace(" (hash mismatch)", "").strip()
                clean_filenames.add(Path(clean_f).name)

            files_to_hash = [
                f for p in patterns for f in local_dir_path.glob(p) if f.is_file() and f.name in clean_filenames
            ]
        else:
            files_to_hash = [f for p in patterns for f in local_dir_path.glob(p) if f.is_file()]

        total_files = len(files_to_hash)

        if files_to_verify:
            report_progress(f"Calculating SHA256 for {total_files} repaired files...", total=total_files, current=0)
        else:
            report_progress("Calculating SHA256 for local files...", total=total_files, current=0)

        # Progress wrapper for hashing
        completed_count = [0]  # Use a list for a mutable closure

        def hash_progress_callback(msg: str):
            if "Using" in msg and "workers" in msg:
                report_progress(msg)
            elif "[" in msg and "/" in msg and "]" in msg:
                # Progress update like: [1/10] ✓ filename (123.4 MB)
                completed_count[0] += 1
                report_progress(msg, total=total_files, current=completed_count[0])

        # Pass the pre-filtered files_to_hash list
        local_hashes = calculate_local_sha256(
            local_dir_path,
            "*.safetensors",  # Unused when files_list is provided
            progress_callback=hash_progress_callback,
            files_list=files_to_hash,
        )
        report_progress(f"✓ Calculated {len(local_hashes)} local file hashes")

        # 3. Compare hashes
        report_progress(f"Comparing {len(official_hashes)} files...", total=len(official_hashes), current=0)

        files_failed = []
        files_missing = []
        files_passed = 0

        for idx, (filename, official_hash) in enumerate(official_hashes.items(), 1):
            file_basename = Path(filename).name

            # Find the matching local file
            local_hash = None
            for local_file, local_hash_value in local_hashes.items():
                if Path(local_file).name == file_basename:
                    local_hash = local_hash_value
                    break

            if local_hash is None:
                files_missing.append(filename)
                detail = f"\n    Remote: {official_hash}\n    Local:  <missing>" if verbose else ""
                report_progress(
                    f"[{idx}/{len(official_hashes)}] ✗ {file_basename} (missing){detail}",
                    total=len(official_hashes),
                    current=idx,
                )
            elif local_hash.lower() != official_hash.lower():
                files_failed.append(f"{filename} (hash mismatch)")
                detail = f"\n    Remote: {official_hash}\n    Local:  {local_hash}" if verbose else ""
                report_progress(
                    f"[{idx}/{len(official_hashes)}] ✗ {file_basename} (hash mismatch){detail}",
                    total=len(official_hashes),
                    current=idx,
                )
            else:
                files_passed += 1
                detail = f"\n    Remote: {official_hash}\n    Local:  {local_hash}" if verbose else ""
                report_progress(
                    f"[{idx}/{len(official_hashes)}] ✓ {file_basename}{detail}",
                    total=len(official_hashes),
                    current=idx,
                )

        # 4. Return results
        total_checked = len(official_hashes)

        if files_failed or files_missing:
            all_failed = files_failed + [f"{f} (missing)" for f in files_missing]
            return {
                "status": "failed",
                "files_checked": total_checked,
                "files_passed": files_passed,
                "files_failed": all_failed,
                "error_message": f"{len(all_failed)} file(s) failed verification",
            }
        else:
            return {
                "status": "passed",
                "files_checked": total_checked,
                "files_passed": files_passed,
                "files_failed": [],
            }

    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
        requests.exceptions.RequestException,
        TimeoutError,  # Socket timeout from socket.setdefaulttimeout()
        OSError,  # Network-related OS errors
    ) as e:
        error_msg = f"Network error: {str(e)}"
        if repo_type == "huggingface":
            error_msg += "\n\nTry using HuggingFace mirror:\n  export HF_ENDPOINT=https://hf-mirror.com"
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": error_msg,
            "is_network_error": True,
        }
    except Exception as e:
        return {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": f"Verification failed: {str(e)}",
            "is_network_error": False,
        }


def pre_operation_verification(user_model, user_registry, operation_name: str = "operation") -> None:
    """Pre-operation verification of model integrity.

    Can be used before running or quantizing models to ensure integrity.

    Args:
        user_model: UserModel object to verify
        user_registry: UserModelRegistry instance
        operation_name: Name of the operation (e.g., "running", "quantizing")
    """
    from rich.prompt import Prompt, Confirm
    from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, MofNCompleteColumn, TimeElapsedColumn
    from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
    from kt_kernel.cli.i18n import get_lang
    from kt_kernel.cli.utils.console import console, print_info, print_warning, print_error, print_success, print_step
    import typer

    lang = get_lang()

    # Skip if the model has already been verified
    if user_model.sha256_status == "passed":
        console.print()
        print_info("Model integrity already verified ✓")
        console.print()
        return

    # Model not verified yet
    console.print()
    console.print("[bold yellow]═══ Model Integrity Check ═══[/bold yellow]")
    console.print()

    # Check whether a repo_id is configured
    if not user_model.repo_id:
        # No repo_id - ask the user to provide one
        console.print("[yellow]No repository ID configured for this model.[/yellow]")
        console.print()
        console.print("To verify model integrity, we need the repository ID (e.g., 'deepseek-ai/DeepSeek-V3')")
        console.print()

        if not Confirm.ask("Would you like to configure the repository ID now?", default=True):
            console.print()
            print_warning(f"Skipping verification. Model will be used for {operation_name} without integrity check.")
            console.print()
            return

        # Ask for the repository type
        console.print()
        console.print("Repository type:")
        console.print("  [cyan][1][/cyan] HuggingFace")
        console.print("  [cyan][2][/cyan] ModelScope")
        console.print()

        repo_type_choice = Prompt.ask("Select repository type", choices=["1", "2"], default="1")
        repo_type = "huggingface" if repo_type_choice == "1" else "modelscope"

        # Ask for the repo_id
        console.print()
        repo_id = Prompt.ask("Enter repository ID (e.g., deepseek-ai/DeepSeek-V3)")

        # Update the model registry
        user_registry.update_model(user_model.name, {"repo_type": repo_type, "repo_id": repo_id})
        user_model.repo_type = repo_type
        user_model.repo_id = repo_id

        console.print()
        print_success(f"Repository configured: {repo_type}:{repo_id}")
        console.print()

    # Now ask whether the user wants to verify
    console.print("[dim]Model integrity verification is a one-time check that ensures your[/dim]")
    console.print("[dim]model weights are not corrupted. This helps prevent runtime errors.[/dim]")
    console.print()

    if not Confirm.ask(f"Would you like to verify model integrity before {operation_name}?", default=True):
        console.print()
        print_warning(f"Skipping verification. Model will be used for {operation_name} without integrity check.")
        console.print()
        return

    # Perform verification
    console.print()
    print_step("Verifying model integrity...")
    console.print()

    # Check connectivity first
    use_mirror = False
    if user_model.repo_type == "huggingface":
        with console.status("[dim]Checking HuggingFace connectivity...[/dim]"):
            is_accessible, message = check_huggingface_connectivity(timeout=5)

        if not is_accessible:
            print_warning("HuggingFace Connection Failed")
            console.print()
            console.print(f"  {message}")
            console.print()
            console.print("  [yellow]Auto-switching to HuggingFace mirror:[/yellow] [cyan]hf-mirror.com[/cyan]")
            console.print()
            use_mirror = True

    # Fetch remote hashes with a timeout
    def fetch_with_timeout(repo_type, repo_id, use_mirror, timeout):
        """Fetch hashes, giving up after the timeout."""
        executor = ThreadPoolExecutor(max_workers=1)
        try:
            platform = "hf" if repo_type == "huggingface" else "ms"
            future = executor.submit(fetch_model_sha256, repo_id, platform, use_mirror=use_mirror, timeout=timeout)
            hashes = future.result(timeout=timeout)
            executor.shutdown(wait=False)
            return (hashes, False)
        except Exception:  # Includes FutureTimeoutError
            executor.shutdown(wait=False)
            return (None, True)

    # Try fetching the hashes
    status = console.status("[dim]Fetching remote hashes...[/dim]")
    status.start()
    official_hashes, timed_out = fetch_with_timeout(user_model.repo_type, user_model.repo_id, use_mirror, 10)
    status.stop()

    # On timeout, fall back to the mirror
    if timed_out and user_model.repo_type == "huggingface" and not use_mirror:
        print_warning("HuggingFace Fetch Timeout (10s)")
        console.print()
        console.print("  [yellow]Trying HuggingFace mirror...[/yellow]")
        console.print()

        status = console.status("[dim]Fetching remote hashes from mirror...[/dim]")
        status.start()
        official_hashes, timed_out = fetch_with_timeout(user_model.repo_type, user_model.repo_id, True, 10)
        status.stop()

    if timed_out and user_model.repo_type == "huggingface":
        print_warning("HuggingFace Mirror Timeout (10s)")
        console.print()
        console.print("  [yellow]Fallback to ModelScope...[/yellow]")
        console.print()

        status = console.status("[dim]Fetching remote hashes from ModelScope...[/dim]")
        status.start()
        official_hashes, timed_out = fetch_with_timeout("modelscope", user_model.repo_id, False, 10)
        status.stop()

    if not official_hashes or timed_out:
        print_error("Failed to fetch remote hashes (network timeout)")
        console.print()
        console.print("  [yellow]Unable to verify model integrity due to network issues.[/yellow]")
        console.print()

        if not Confirm.ask(f"Continue {operation_name} without verification?", default=False):
            raise typer.Exit(0)

        console.print()
        return

    console.print(f"  [green]✓ Fetched {len(official_hashes)} file hashes[/green]")
    console.print()

    # Calculate local hashes and compare
    local_dir = Path(user_model.path)
    files_to_hash = [f for f in local_dir.glob("*.safetensors") if f.is_file()]

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        console=console,
    ) as progress:
        # Calculate local hashes
        task = progress.add_task("[yellow]Calculating local SHA256...", total=len(files_to_hash))

        def hash_callback(msg):
            if "[" in msg and "/" in msg and "]" in msg and "✓" in msg:
                progress.advance(task)

        local_hashes = calculate_local_sha256(local_dir, "*.safetensors", progress_callback=hash_callback)
|
||||
progress.remove_task(task)
|
||||
|
||||
console.print(f" [green]✓ Calculated {len(local_hashes)} local hashes[/green]")
|
||||
console.print()
|
||||
|
||||
# Compare hashes
|
||||
task = progress.add_task("[blue]Comparing hashes...", total=len(official_hashes))
|
||||
|
||||
files_failed = []
|
||||
files_missing = []
|
||||
files_passed = 0
|
||||
|
||||
for filename, official_hash in official_hashes.items():
|
||||
file_basename = Path(filename).name
|
||||
local_hash = None
|
||||
|
||||
for local_file, local_hash_value in local_hashes.items():
|
||||
if Path(local_file).name == file_basename:
|
||||
local_hash = local_hash_value
|
||||
break
|
||||
|
||||
if local_hash is None:
|
||||
files_missing.append(filename)
|
||||
elif local_hash.lower() != official_hash.lower():
|
||||
files_failed.append(f"{filename} (hash mismatch)")
|
||||
else:
|
||||
files_passed += 1
|
||||
|
||||
progress.advance(task)
|
||||
|
||||
progress.remove_task(task)
|
||||
|
||||
console.print()
|
||||
|
||||
# Check results
|
||||
if not files_failed and not files_missing:
|
||||
# Verification passed
|
||||
user_registry.update_model(user_model.name, {"sha256_status": "passed"})
|
||||
print_success("Model integrity verification PASSED ✓")
|
||||
console.print()
|
||||
console.print(f" All {files_passed} files verified successfully")
|
||||
console.print()
|
||||
else:
|
||||
# Verification failed
|
||||
user_registry.update_model(user_model.name, {"sha256_status": "failed"})
|
||||
print_error(f"Model integrity verification FAILED")
|
||||
console.print()
|
||||
console.print(f" ✓ Passed: [green]{files_passed}[/green]")
|
||||
console.print(f" ✗ Failed: [red]{len(files_failed) + len(files_missing)}[/red]")
|
||||
console.print()
|
||||
|
||||
if files_missing:
|
||||
console.print(f" [red]Missing files ({len(files_missing)}):[/red]")
|
||||
for f in files_missing[:5]:
|
||||
console.print(f" - {Path(f).name}")
|
||||
if len(files_missing) > 5:
|
||||
console.print(f" ... and {len(files_missing) - 5} more")
|
||||
console.print()
|
||||
|
||||
if files_failed:
|
||||
console.print(f" [red]Hash mismatch ({len(files_failed)}):[/red]")
|
||||
for f in files_failed[:5]:
|
||||
console.print(f" - {f}")
|
||||
if len(files_failed) > 5:
|
||||
console.print(f" ... and {len(files_failed) - 5} more")
|
||||
console.print()
|
||||
|
||||
console.print("[bold red]⚠ WARNING: Model weights may be corrupted![/bold red]")
|
||||
console.print()
|
||||
console.print("This could cause runtime errors or incorrect inference results.")
|
||||
console.print()
|
||||
|
||||
# Ask if user wants to repair
|
||||
if Confirm.ask("Would you like to repair (re-download) the corrupted files?", default=True):
|
||||
console.print()
|
||||
print_info("Please run: [cyan]kt model verify " + user_model.name + "[/cyan]")
|
||||
console.print()
|
||||
console.print("The verify command will guide you through the repair process.")
|
||||
raise typer.Exit(0)
|
||||
|
||||
# Ask if user wants to continue anyway
|
||||
console.print()
|
||||
if not Confirm.ask(
|
||||
f"[yellow]Continue {operation_name} with potentially corrupted weights?[/yellow]", default=False
|
||||
):
|
||||
raise typer.Exit(0)
|
||||
|
||||
console.print()
|
||||
print_warning(f"Proceeding with {operation_name} using unverified weights at your own risk...")
|
||||
console.print()
|
||||
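For reference, the comparison above reduces to hashing each local *.safetensors file and matching it case-insensitively against the repository manifest. A minimal standalone sketch using only the standard library (the manifest entry and file name below are hypothetical; the CLI itself uses fetch_model_sha256 and calculate_local_sha256):

# Minimal sketch of the per-file check performed above (assumed inputs).
import hashlib
from pathlib import Path

def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

official = {"model-00001-of-00002.safetensors": "ab12..."}  # hypothetical manifest entry
local_file = Path("model-00001-of-00002.safetensors")       # hypothetical local file
ok = local_file.exists() and sha256_of(local_file).lower() == official[local_file.name].lower()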
kt-kernel/python/cli/utils/port_checker.py (new file, 57 lines)
@@ -0,0 +1,57 @@
"""
Port availability checking utilities.
"""

import socket
from typing import Tuple


def is_port_available(host: str, port: int) -> bool:
    """Check if a port is available on the given host.

    Args:
        host: Host address (e.g., "0.0.0.0", "127.0.0.1")
        port: Port number to check

    Returns:
        True if port is available, False if occupied
    """
    try:
        # Create a TCP socket and probe the port with a connection attempt
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(1)

        # SO_REUSEADDR avoids false positives from recently closed connections
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        # Attempt to connect; a wildcard host is probed via loopback
        result = sock.connect_ex((host if host != "0.0.0.0" else "127.0.0.1", port))
        sock.close()

        # connect_ex returns 0 when something is listening (port occupied);
        # a non-zero error code means nothing accepted the connection (available)
        return result != 0

    except Exception:
        # If any error occurs, conservatively assume the port is not available
        return False


def find_available_port(host: str, start_port: int, max_attempts: int = 100) -> Tuple[bool, int]:
    """Find an available port starting from start_port.

    Args:
        host: Host address
        start_port: Starting port number to check
        max_attempts: Maximum number of ports to try

    Returns:
        Tuple of (found, port_number)
        - found: True if an available port was found
        - port_number: The available port number (or start_port if not found)
    """
    for port in range(start_port, start_port + max_attempts):
        if is_port_available(host, port):
            return True, port

    return False, start_port
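A minimal caller sketch for the helpers above (host and port values are illustrative):

# Illustrative caller: prefer port 30000, otherwise take the next free one.
found, port = find_available_port("127.0.0.1", 30000)
if not found:
    raise RuntimeError("No free port in range 30000-30099")
print(f"Serving on port {port}")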
kt-kernel/python/cli/utils/quant_interactive.py (new file, 347 lines)
@@ -0,0 +1,347 @@
"""
Interactive configuration for kt quant command.

Provides rich, multi-step interactive configuration for model quantization.
"""

from typing import Optional, Dict, Any
from pathlib import Path
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.prompt import Prompt, Confirm, IntPrompt
from kt_kernel.cli.i18n import t


console = Console()


def select_model_to_quantize() -> Optional[Any]:
    """Select the model to quantize interactively."""
    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
    from kt_kernel.cli.commands.model import is_amx_weights, SHA256_STATUS_MAP
    from kt_kernel.cli.utils.model_table_builder import build_moe_gpu_table

    registry = UserModelRegistry()
    all_models = registry.list_models()

    # Filter to MoE models only (safetensors, not AMX, is_moe=True)
    quant_models = []
    for model in all_models:
        if model.format == "safetensors":
            # Skip AMX models
            is_amx, _ = is_amx_weights(model.path)
            if is_amx:
                continue

            # Only include MoE models
            if model.is_moe:
                quant_models.append(model)

    if not quant_models:
        console.print(f"[yellow]{t('quant_no_moe_models')}[/yellow]")
        console.print()
        console.print(f"  {t('quant_only_moe')}")
        console.print()
        console.print(f"  {t('quant_add_models', command='kt model scan')}")
        console.print(f"  {t('quant_add_models', command='kt model add <path>')}")
        return None

    # Display models
    console.print()
    console.print(f"[bold green]{t('quant_moe_available')}[/bold green]")
    console.print()

    # Use the shared table builder
    table, displayed_models = build_moe_gpu_table(
        models=quant_models, status_map=SHA256_STATUS_MAP, show_index=True, start_index=1
    )

    console.print(table)
    console.print()

    choice = IntPrompt.ask(t("quant_select_model"), default=1, show_choices=False)

    if choice < 1 or choice > len(displayed_models):
        console.print(f"[red]{t('quant_invalid_choice')}[/red]")
        return None

    return displayed_models[choice - 1]


def configure_quantization_method() -> Dict[str, str]:
    """Select quantization method and input type."""
    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_step2_method')}[/bold cyan]", expand=False))
    console.print()

    # Method selection
    console.print(f"[bold]{t('quant_method_label')}[/bold]")
    console.print(f"  [cyan][1][/cyan] {t('quant_int4_desc')}")
    console.print(f"  [cyan][2][/cyan] {t('quant_int8_desc')}")
    console.print()

    method_choice = Prompt.ask(t("quant_select_method"), choices=["1", "2"], default="1")
    method = "int4" if method_choice == "1" else "int8"

    console.print()
    console.print(f"[bold]{t('quant_input_type_label')}[/bold]")
    console.print(f"  [cyan][1][/cyan] {t('quant_fp8_desc')}")
    console.print(f"  [cyan][2][/cyan] {t('quant_fp16_desc')}")
    console.print(f"  [cyan][3][/cyan] {t('quant_bf16_desc')}")
    console.print()

    input_choice = Prompt.ask(t("quant_select_input_type"), choices=["1", "2", "3"], default="1")
    input_type_map = {"1": "fp8", "2": "fp16", "3": "bf16"}
    input_type = input_type_map[input_choice]

    return {"method": method, "input_type": input_type}


def configure_cpu_params(max_cores: int, max_numa: int) -> Dict[str, Any]:
    """Configure CPU parameters."""
    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_step3_cpu')}[/bold cyan]", expand=False))
    console.print()

    def clamp(value: int, min_val: int, max_val: int, default: int) -> int:
        """Return value if it is within [min_val, max_val], otherwise the default."""
        if min_val <= value <= max_val:
            return value
        return default

    default_threads = int(max_cores * 0.8)
    cpu_threads = IntPrompt.ask(t("quant_cpu_threads_prompt", max=max_cores), default=default_threads)
    cpu_threads = clamp(cpu_threads, 1, max_cores, default_threads)

    numa_nodes = IntPrompt.ask(t("quant_numa_nodes_prompt", max=max_numa), default=max_numa)
    numa_nodes = clamp(numa_nodes, 1, max_numa, max_numa)

    # Ask about GPU usage
    console.print()
    console.print(f"[bold]{t('quant_use_gpu_label')}[/bold]")
    console.print(f"  [dim]{t('quant_gpu_speedup')}[/dim]")
    console.print()
    use_gpu = Confirm.ask(t("quant_enable_gpu"), default=True)

    return {"cpu_threads": cpu_threads, "numa_nodes": numa_nodes, "use_gpu": use_gpu}


def configure_output_path(model: Any, method: str, numa_nodes: int) -> Path:
    """Configure the output path for quantized weights."""
    from kt_kernel.cli.config.settings import get_settings

    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_step4_output')}[/bold cyan]", expand=False))
    console.print()

    # Generate the default output path
    model_path = Path(model.path)
    method_upper = method.upper()
    settings = get_settings()

    # Priority: paths.weights > paths.models[0] > model's parent directory
    weights_dir = settings.weights_dir
    if weights_dir and weights_dir.exists():
        # Use the configured weights directory (highest priority)
        default_output = weights_dir / f"{model_path.name}-AMX{method_upper}-NUMA{numa_nodes}"
    else:
        # Use the first model storage path
        model_paths = settings.get_model_paths()
        if model_paths and model_paths[0].exists():
            default_output = model_paths[0] / f"{model_path.name}-AMX{method_upper}-NUMA{numa_nodes}"
        else:
            # Fall back to the model's parent directory
            default_output = model_path.parent / f"{model_path.name}-AMX{method_upper}-NUMA{numa_nodes}"

    console.print(f"[dim]{t('quant_default_path')}[/dim] {default_output}")
    console.print()

    use_default = Confirm.ask(t("quant_use_default"), default=True)

    if use_default:
        return default_output

    custom_path = Prompt.ask(t("quant_custom_path"), default=str(default_output))

    return Path(custom_path)


def calculate_quantized_size(source_path: Path, input_type: str, quant_method: str) -> tuple[float, float]:
    """
    Calculate the source model size and the estimated quantized size.

    Args:
        source_path: Path to the source model
        input_type: Input type (fp8, fp16, bf16)
        quant_method: Quantization method (int4, int8)

    Returns:
        Tuple of (source_size_gb, estimated_quant_size_gb)
    """
    # Calculate the source model size
    try:
        total_bytes = sum(f.stat().st_size for f in source_path.glob("*.safetensors") if f.is_file())
        source_size_gb = total_bytes / (1024**3)
    except Exception:
        return 0.0, 0.0

    # Bits mapping
    input_bits = {"fp8": 8, "fp16": 16, "bf16": 16}
    quant_bits = {"int4": 4, "int8": 8}

    input_bit = input_bits.get(input_type, 16)
    quant_bit = quant_bits.get(quant_method, 4)

    # Estimate: source_size * (quant_bits / input_bits)
    ratio = quant_bit / input_bit
    estimated_size_gb = source_size_gb * ratio

    return source_size_gb, estimated_size_gb


def check_disk_space(output_path: Path, required_size_gb: float) -> tuple[float, bool]:
    """
    Check the available disk space at the output path.

    Args:
        output_path: Target output path
        required_size_gb: Required space in GB

    Returns:
        Tuple of (available_gb, is_sufficient)
        is_sufficient is True if available >= required * 1.2
    """
    import shutil

    try:
        # Walk up to the nearest existing parent directory
        check_path = output_path.parent if not output_path.exists() else output_path
        while not check_path.exists() and check_path != check_path.parent:
            check_path = check_path.parent

        stat = shutil.disk_usage(check_path)
        available_gb = stat.free / (1024**3)

        # Check if available space >= required * 1.2 (20% buffer)
        is_sufficient = available_gb >= (required_size_gb * 1.2)

        return available_gb, is_sufficient
    except Exception:
        return 0.0, False


def interactive_quant_config() -> Optional[Dict[str, Any]]:
    """
    Interactive configuration for kt quant.

    Returns a configuration dict, or None if cancelled.
    """
    from kt_kernel.cli.utils.environment import detect_cpu_info

    # Get CPU info
    cpu_info = detect_cpu_info()

    # Step 1: Select model
    model = select_model_to_quantize()
    if not model:
        return None

    # Step 1.5: Pre-quantization verification (optional)
    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
    from kt_kernel.cli.utils.model_verifier import pre_operation_verification

    user_registry = UserModelRegistry()
    user_model_obj = user_registry.find_by_path(model.path)

    if user_model_obj and user_model_obj.format == "safetensors":
        pre_operation_verification(user_model_obj, user_registry, operation_name="quantizing")

    # Step 2: Configure quantization method
    quant_config = configure_quantization_method()

    # Step 3: Configure CPU parameters
    cpu_config = configure_cpu_params(cpu_info.threads, cpu_info.numa_nodes)  # Use logical threads

    # Step 4: Configure output path
    output_path = configure_output_path(model, quant_config["method"], cpu_config["numa_nodes"])

    # Step 4.5: If the output path already exists, generate a unique name
    if output_path.exists():
        console.print()
        console.print(t("quant_output_exists_warn", path=str(output_path)))
        console.print()

        # Generate a unique name by adding a numeric suffix
        original_name = output_path.name
        parent_dir = output_path.parent
        counter = 2

        while output_path.exists():
            new_name = f"{original_name}-{counter}"
            output_path = parent_dir / new_name
            counter += 1

        console.print(t("quant_using_unique_name", path=str(output_path)))
        console.print()

    # Step 5: Calculate space requirements and check availability
    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_disk_analysis')}[/bold cyan]", expand=False))
    console.print()

    source_size_gb, estimated_size_gb = calculate_quantized_size(
        Path(model.path), quant_config["input_type"], quant_config["method"]
    )

    available_gb, is_sufficient = check_disk_space(output_path, estimated_size_gb)

    console.print(f"  {t('quant_source_size'):<26} [cyan]{source_size_gb:.2f} GB[/cyan]")
    console.print(f"  {t('quant_estimated_size'):<26} [yellow]{estimated_size_gb:.2f} GB[/yellow]")
    console.print(
        f"  {t('quant_available_space'):<26} [{'green' if is_sufficient else 'red'}]{available_gb:.2f} GB[/{'green' if is_sufficient else 'red'}]"
    )
    console.print()

    if not is_sufficient:
        required_with_buffer = estimated_size_gb * 1.2
        console.print(f"[bold red]⚠ {t('quant_insufficient_space')}[/bold red]")
        console.print()
        console.print(f"  {t('quant_required_space'):<26} [yellow]{required_with_buffer:.2f} GB[/yellow]")
        console.print(f"  {t('quant_available_space'):<26} [red]{available_gb:.2f} GB[/red]")
        console.print(f"  {t('quant_shortage'):<26} [red]{required_with_buffer - available_gb:.2f} GB[/red]")
        console.print()
        console.print(f"  {t('quant_may_fail')}")
        console.print()

        if not Confirm.ask(f"[yellow]{t('quant_continue_anyway')}[/yellow]", default=False):
            console.print(f"[yellow]{t('quant_cancelled')}[/yellow]")
            return None
        console.print()

    # Summary and confirmation
    console.print()
    console.print(Panel(f"[bold cyan]{t('quant_config_summary')}[/bold cyan]", expand=False))
    console.print()
    console.print(f"  {t('quant_summary_model'):<15} {model.name}")
    console.print(f"  {t('quant_summary_method'):<15} {quant_config['method'].upper()}")
    console.print(f"  {t('quant_summary_input_type'):<15} {quant_config['input_type'].upper()}")
    console.print(f"  {t('quant_summary_cpu_threads'):<15} {cpu_config['cpu_threads']}")
    console.print(f"  {t('quant_summary_numa'):<15} {cpu_config['numa_nodes']}")
    console.print(f"  {t('quant_summary_gpu'):<15} {t('yes') if cpu_config['use_gpu'] else t('no')}")
    console.print(f"  {t('quant_summary_output'):<15} {output_path}")
    console.print()

    if not Confirm.ask(f"[bold green]{t('quant_start_question')}[/bold green]", default=True):
        console.print(f"[yellow]{t('quant_cancelled')}[/yellow]")
        return None

    return {
        "model": model,
        "method": quant_config["method"],
        "input_type": quant_config["input_type"],
        "cpu_threads": cpu_config["cpu_threads"],
        "numa_nodes": cpu_config["numa_nodes"],
        "use_gpu": cpu_config["use_gpu"],
        "output_path": output_path,
    }
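As a worked example of the estimate above: int4 output from an fp8 source gives a ratio of 4/8 = 0.5, so a 600 GB source is estimated at 300 GB, and the disk check then requires 300 * 1.2 = 360 GB free. A quick sketch under those assumed numbers:

# Worked example of the size estimate and disk check (values are illustrative).
source_gb = 600.0
ratio = 4 / 8                              # int4 output from fp8 input
estimated_gb = source_gb * ratio           # 300.0
required_with_buffer = estimated_gb * 1.2  # 360.0 (20% safety buffer)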
kt-kernel/python/cli/utils/repo_detector.py (new file, 364 lines)
@@ -0,0 +1,364 @@
"""
Repo Detector

Automatically detect repository information from model README.md files
"""

import re
from pathlib import Path
from typing import Optional, Dict, Tuple
import yaml


def parse_readme_frontmatter(readme_path: Path) -> Optional[Dict]:
    """
    Parse YAML frontmatter from README.md

    Args:
        readme_path: Path to README.md file

    Returns:
        Dictionary of frontmatter data, or None if not found
    """
    if not readme_path.exists():
        return None

    try:
        with open(readme_path, "r", encoding="utf-8") as f:
            content = f.read()

        # Match YAML frontmatter between --- markers
        match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)
        if not match:
            return None

        yaml_content = match.group(1)

        # Parse YAML
        try:
            data = yaml.safe_load(yaml_content)
            return data if isinstance(data, dict) else None
        except yaml.YAMLError:
            return None

    except Exception:
        return None


def extract_repo_from_frontmatter(frontmatter: Dict) -> Optional[Tuple[str, str]]:
    """
    Extract repo_id and repo_type from frontmatter

    Args:
        frontmatter: Parsed YAML frontmatter dictionary

    Returns:
        Tuple of (repo_id, repo_type) or None
        repo_type is either "huggingface" or "modelscope"
    """
    if not frontmatter:
        return None

    # Priority 1: Extract from license_link (most reliable)
    license_link = frontmatter.get("license_link")
    if license_link and isinstance(license_link, str):
        result = _extract_repo_from_url(license_link)
        if result:
            return result

    # Priority 2: Try to find repo_id from other fields
    repo_id = None

    # Check base_model field
    base_model = frontmatter.get("base_model")
    if base_model:
        if isinstance(base_model, list) and len(base_model) > 0:
            # base_model is a list, take first item
            repo_id = base_model[0]
        elif isinstance(base_model, str):
            repo_id = base_model

    # Check model-index field
    if not repo_id:
        model_index = frontmatter.get("model-index")
        if isinstance(model_index, list) and len(model_index) > 0:
            first_model = model_index[0]
            if isinstance(first_model, dict):
                repo_id = first_model.get("name")

    # Check model_name field
    if not repo_id:
        repo_id = frontmatter.get("model_name")

    if not repo_id or not isinstance(repo_id, str):
        return None

    # Validate format: should be "namespace/model-name"
    if "/" not in repo_id:
        return None

    parts = repo_id.split("/")
    if len(parts) != 2:
        return None

    # Determine repo type
    repo_type = "huggingface"  # Default

    # Look for ModelScope indicators
    if "modelscope" in repo_id.lower():
        repo_type = "modelscope"

    # Check tags
    tags = frontmatter.get("tags", [])
    if isinstance(tags, list):
        if "modelscope" in [str(tag).lower() for tag in tags]:
            repo_type = "modelscope"

    return (repo_id, repo_type)


def _extract_repo_from_url(url: str) -> Optional[Tuple[str, str]]:
    """
    Extract repo_id and repo_type from a URL

    Supports:
    - https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/LICENSE
    - https://modelscope.cn/models/Qwen/Qwen3-30B-A3B

    Args:
        url: URL string

    Returns:
        Tuple of (repo_id, repo_type) or None
    """
    # HuggingFace pattern: https://huggingface.co/{namespace}/{model}/...
    hf_match = re.match(r"https?://huggingface\.co/([^/]+)/([^/]+)", url)
    if hf_match:
        namespace = hf_match.group(1)
        model_name = hf_match.group(2)
        repo_id = f"{namespace}/{model_name}"
        return (repo_id, "huggingface")

    # ModelScope pattern: https://modelscope.cn/models/{namespace}/{model}
    ms_match = re.match(r"https?://(?:www\.)?modelscope\.cn/models/([^/]+)/([^/]+)", url)
    if ms_match:
        namespace = ms_match.group(1)
        model_name = ms_match.group(2)
        repo_id = f"{namespace}/{model_name}"
        return (repo_id, "modelscope")

    return None


def extract_repo_from_global_search(readme_path: Path) -> Optional[Tuple[str, str]]:
    """
    Extract repo info by globally searching for URLs in README.md

    Args:
        readme_path: Path to README.md file

    Returns:
        Tuple of (repo_id, repo_type) or None if not found
    """
    if not readme_path.exists():
        return None

    try:
        with open(readme_path, "r", encoding="utf-8") as f:
            content = f.read()

        # Find all HuggingFace URLs
        hf_pattern = r"https?://huggingface\.co/([^/\s]+)/([^/\s\)]+)"
        hf_matches = re.findall(hf_pattern, content)

        # Find all ModelScope URLs
        ms_pattern = r"https?://(?:www\.)?modelscope\.cn/models/([^/\s]+)/([^/\s\)]+)"
        ms_matches = re.findall(ms_pattern, content)

        # Collect all found repos with their types
        found_repos = []

        for namespace, model_name in hf_matches:
            # Skip common non-repo paths
            if namespace.lower() in ["docs", "blog", "spaces", "datasets"]:
                continue
            if model_name.lower() in ["tree", "blob", "raw", "resolve", "discussions"]:
                continue

            repo_id = f"{namespace}/{model_name}"
            found_repos.append((repo_id, "huggingface"))

        for namespace, model_name in ms_matches:
            repo_id = f"{namespace}/{model_name}"
            found_repos.append((repo_id, "modelscope"))

        if not found_repos:
            return None

        # If multiple different repos are referenced, use the last one mentioned
        return found_repos[-1]

    except Exception:
        return None


def detect_repo_for_model(model_path: str) -> Optional[Tuple[str, str]]:
    """
    Detect repository information for a model

    Strategy:
        Only extract from YAML frontmatter metadata in README.md
        (Removed global URL search to avoid false positives)

    Args:
        model_path: Path to model directory

    Returns:
        Tuple of (repo_id, repo_type) or None if not detected
    """
    model_dir = Path(model_path)

    if not model_dir.exists() or not model_dir.is_dir():
        return None

    # Look for README.md
    readme_path = model_dir / "README.md"
    if not readme_path.exists():
        return None

    # Only parse YAML frontmatter (no fallback to global search)
    frontmatter = parse_readme_frontmatter(readme_path)
    if frontmatter:
        return extract_repo_from_frontmatter(frontmatter)

    return None


def scan_models_for_repo(model_list) -> Dict:
    """
    Scan a list of models and detect repo information

    Args:
        model_list: List of UserModel objects

    Returns:
        Dictionary with scan results:
        {
            'detected': [(model, repo_id, repo_type), ...],
            'not_detected': [model, ...],
            'skipped': [model, ...]  # Already has repo_id
        }
    """
    results = {"detected": [], "not_detected": [], "skipped": []}

    for model in model_list:
        # Skip if already has repo_id
        if model.repo_id:
            results["skipped"].append(model)
            continue

        # Only process safetensors and gguf models
        if model.format not in ["safetensors", "gguf"]:
            results["skipped"].append(model)
            continue

        # Try to detect repo
        repo_info = detect_repo_for_model(model.path)

        if repo_info:
            repo_id, repo_type = repo_info
            results["detected"].append((model, repo_id, repo_type))
        else:
            results["not_detected"].append(model)

    return results


def format_detection_report(results: Dict) -> str:
    """
    Format scan results into a readable report

    Args:
        results: Results from scan_models_for_repo()

    Returns:
        Formatted string report
    """
    lines = []

    lines.append("=" * 80)
    lines.append("Auto-Detection Report")
    lines.append("=" * 80)
    lines.append("")

    # Detected
    if results["detected"]:
        lines.append(f"✓ Detected repository information ({len(results['detected'])} models):")
        lines.append("")
        for model, repo_id, repo_type in results["detected"]:
            lines.append(f"  • {model.name}")
            lines.append(f"    Path: {model.path}")
            lines.append(f"    Repo: {repo_id} ({repo_type})")
            lines.append("")

    # Not detected
    if results["not_detected"]:
        lines.append(f"✗ No repository information found ({len(results['not_detected'])} models):")
        lines.append("")
        for model in results["not_detected"]:
            lines.append(f"  • {model.name}")
            lines.append(f"    Path: {model.path}")
            lines.append("")

    # Skipped
    if results["skipped"]:
        lines.append(f"⊘ Skipped ({len(results['skipped'])} models):")
        lines.append("  (Already have repo_id or not safetensors/gguf format)")
        lines.append("")

    lines.append("=" * 80)
    lines.append(
        f"Summary: {len(results['detected'])} detected, "
        f"{len(results['not_detected'])} not detected, "
        f"{len(results['skipped'])} skipped"
    )
    lines.append("=" * 80)

    return "\n".join(lines)


def apply_detection_results(results: Dict, registry) -> int:
    """
    Apply detected repo information to models in registry

    Args:
        results: Results from scan_models_for_repo()
        registry: UserModelRegistry instance

    Returns:
        Number of models updated
    """
    updated_count = 0

    for model, repo_id, repo_type in results["detected"]:
        success = registry.update_model(model.name, {"repo_id": repo_id, "repo_type": repo_type})

        if success:
            updated_count += 1

    return updated_count
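For instance, a README.md that starts with frontmatter like the hypothetical block below resolves through the license_link URL (Priority 1):

# Hypothetical model directory; detect_repo_for_model() reads
# <model_dir>/README.md and parses the leading frontmatter:
#
#   ---
#   license: apache-2.0
#   license_link: https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/LICENSE
#   ---
#
repo = detect_repo_for_model("/models/Qwen3-30B-A3B")  # illustrative path
# repo == ("Qwen/Qwen3-30B-A3B", "huggingface")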
kt-kernel/python/cli/utils/run_configs.py (new file, 111 lines)
@@ -0,0 +1,111 @@
"""
Configuration save/load for kt run command.

Manages saved run configurations bound to specific models.
"""

from pathlib import Path
from typing import Dict, List, Optional, Any
from datetime import datetime
import yaml


CONFIG_FILE = Path.home() / ".ktransformers" / "run_configs.yaml"


class RunConfigManager:
    """Manager for saved run configurations."""

    def __init__(self):
        self.config_file = CONFIG_FILE
        self._ensure_config_file()

    def _ensure_config_file(self):
        """Ensure config file exists."""
        if not self.config_file.exists():
            self.config_file.parent.mkdir(parents=True, exist_ok=True)
            self._save_data({"version": "1.0", "configs": {}})

    def _load_data(self) -> Dict:
        """Load raw config data."""
        try:
            with open(self.config_file, "r", encoding="utf-8") as f:
                return yaml.safe_load(f) or {"version": "1.0", "configs": {}}
        except Exception:
            return {"version": "1.0", "configs": {}}

    def _save_data(self, data: Dict):
        """Save raw config data."""
        with open(self.config_file, "w", encoding="utf-8") as f:
            yaml.dump(data, f, allow_unicode=True, default_flow_style=False)

    def list_configs(self, model_id: str) -> List[Dict[str, Any]]:
        """List all saved configs for a model.

        Returns:
            List of config dicts with 'config_name' and other fields.
        """
        data = self._load_data()
        configs = data.get("configs", {}).get(model_id, [])
        return configs if isinstance(configs, list) else []

    def save_config(self, model_id: str, config: Dict[str, Any]):
        """Save a configuration for a model.

        Args:
            model_id: Model ID to bind config to
            config: Configuration dict with all run parameters
        """
        data = self._load_data()

        if "configs" not in data:
            data["configs"] = {}

        if model_id not in data["configs"]:
            data["configs"][model_id] = []

        # Add timestamp
        config["created_at"] = datetime.now().isoformat()

        # Append config
        data["configs"][model_id].append(config)

        self._save_data(data)

    def delete_config(self, model_id: str, config_index: int) -> bool:
        """Delete a saved configuration.

        Args:
            model_id: Model ID
            config_index: Index of config to delete (0-based)

        Returns:
            True if deleted, False if not found
        """
        data = self._load_data()

        if model_id not in data.get("configs", {}):
            return False

        configs = data["configs"][model_id]
        if config_index < 0 or config_index >= len(configs):
            return False

        configs.pop(config_index)
        self._save_data(data)
        return True

    def get_config(self, model_id: str, config_index: int) -> Optional[Dict[str, Any]]:
        """Get a specific saved configuration.

        Args:
            model_id: Model ID
            config_index: Index of config to get (0-based)

        Returns:
            Config dict or None if not found
        """
        configs = self.list_configs(model_id)
        if config_index < 0 or config_index >= len(configs):
            return None
        return configs[config_index]
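A typical round trip with the manager (the model ID and parameters below are illustrative):

# Illustrative round trip: save a config, list it, load it back.
mgr = RunConfigManager()
mgr.save_config("deepseek-v3", {"config_name": "2xGPU-tp2", "tensor_parallel_size": 2})
configs = mgr.list_configs("deepseek-v3")  # [{'config_name': '2xGPU-tp2', ..., 'created_at': ...}]
cfg = mgr.get_config("deepseek-v3", 0)     # same dict, or None if the index is out of range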
kt-kernel/python/cli/utils/run_interactive.py (new file, 1084 lines)
File diff suppressed because it is too large
kt-kernel/python/cli/utils/tuna_engine.py (new file, 459 lines)
@@ -0,0 +1,459 @@
"""
Tuna engine for auto-tuning GPU experts configuration.

Automatically finds the maximum viable num-gpu-experts through binary search
by testing actual server launches with different configurations.
"""

import json
import math
import random
import re
import subprocess
import sys
import time
from pathlib import Path
from typing import Optional

from kt_kernel.cli.utils.console import console, print_error, print_info, print_warning


def get_num_experts(model_path: Path) -> int:
    """
    Get the number of experts per layer from the model config.

    Args:
        model_path: Path to the model directory

    Returns:
        Number of experts per layer

    Raises:
        ValueError: If config.json is not found or the experts field is missing
    """
    config_file = model_path / "config.json"

    if not config_file.exists():
        raise ValueError(f"config.json not found in {model_path}")

    try:
        config = json.loads(config_file.read_text())
    except Exception as e:
        raise ValueError(f"Failed to parse config.json: {e}")

    # Different models use different field names for the total number of
    # routed experts per layer. Note: num_experts_per_tok is the *active*
    # expert count per token, not the total, so it is deliberately not used.
    possible_keys = [
        "n_routed_experts",   # DeepSeek
        "num_local_experts",  # Mixtral
        "num_experts",        # Qwen3-MoE / generic
    ]

    for key in possible_keys:
        if key in config:
            return config[key]

    raise ValueError(f"Cannot find num_experts field in {config_file}. Tried: {', '.join(possible_keys)}")


def detect_oom(log_line: Optional[str]) -> bool:
    """
    Detect OOM (Out Of Memory) errors from log output.

    Args:
        log_line: A line from server output

    Returns:
        True if OOM detected, False otherwise
    """
    if log_line is None:
        return False

    log_lower = log_line.lower()

    # Match "oom" only as a whole word to avoid false positives ("room", "zoom")
    if re.search(r"\boom\b", log_lower):
        return True

    oom_patterns = [
        "cuda out of memory",
        "out of memory",
        "outofmemoryerror",
        "failed to allocate",
        "cumemalloc failed",
        "cumemallocasync failed",
        "allocation failed",
    ]

    return any(pattern in log_lower for pattern in oom_patterns)


def test_config(
    num_gpu_experts: int,
    model_path: Path,
    config: dict,
    verbose: bool = False,
) -> tuple[bool, float]:
    """
    Test whether a configuration with the given num_gpu_experts works.

    Args:
        num_gpu_experts: Number of GPU experts to test
        model_path: Path to the model
        config: Configuration dict with all parameters
        verbose: Whether to show detailed logs

    Returns:
        (success: bool, elapsed_time: float)
        - success: True if the server starts and inference works
        - elapsed_time: Time taken for the test
    """
    start_time = time.time()

    # Use a random port to avoid conflicts
    test_port = random.randint(30000, 40000)

    # Build command
    cmd = [
        sys.executable,
        "-m",
        "sglang.launch_server",
        "--model",
        str(model_path),
        "--port",
        str(test_port),
        "--host",
        "127.0.0.1",
        "--tensor-parallel-size",
        str(config["tensor_parallel_size"]),
        "--kt-num-gpu-experts",
        str(num_gpu_experts),
        "--max-total-tokens",
        str(config["max_total_tokens"]),
    ]

    # Add kt-kernel options
    if config.get("weights_path"):
        cmd.extend(["--kt-weight-path", str(config["weights_path"])])
    else:
        cmd.extend(["--kt-weight-path", str(model_path)])

    cmd.extend(
        [
            "--kt-cpuinfer",
            str(config.get("cpu_threads", 64)),
            "--kt-threadpool-count",
            str(config.get("numa_nodes", 2)),
            "--kt-method",
            config.get("kt_method", "AMXINT4"),
            "--kt-gpu-prefill-token-threshold",
            str(config.get("kt_gpu_prefill_threshold", 4096)),
        ]
    )

    # Add other SGLang options
    if config.get("attention_backend"):
        cmd.extend(["--attention-backend", config["attention_backend"]])

    cmd.extend(
        [
            "--trust-remote-code",
            "--mem-fraction-static",
            str(config.get("mem_fraction_static", 0.98)),
            "--chunked-prefill-size",
            str(config.get("chunked_prefill_size", 4096)),
            "--max-running-requests",
            str(config.get("max_running_requests", 1)),  # Use 1 for faster testing
            "--watchdog-timeout",
            str(config.get("watchdog_timeout", 3000)),
            "--enable-mixed-chunk",
            "--enable-p2p-check",
        ]
    )

    # Add disable-shared-experts-fusion if specified
    if config.get("disable_shared_experts_fusion"):
        cmd.append("--disable-shared-experts-fusion")

    # Add extra args
    if config.get("extra_args"):
        cmd.extend(config["extra_args"])

    if verbose:
        console.print(f"[dim]Command: {' '.join(cmd)}[/dim]")

    # Start process
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            env=config.get("env"),
        )
    except Exception as e:
        if verbose:
            print_error(f"Failed to start process: {e}")
        return False, time.time() - start_time

    # Monitor process output
    timeout = 60  # Maximum 60 seconds to wait
    server_ready = False

    try:
        while time.time() - start_time < timeout:
            # Check if the process has exited
            if process.poll() is not None:
                if verbose:
                    print_warning("Process exited early")
                return False, time.time() - start_time

            # Read an output line (non-blocking)
            try:
                line = process.stdout.readline()
                if not line:
                    time.sleep(0.1)
                    continue

                if verbose:
                    console.print(f"[dim]{line.rstrip()}[/dim]")

                # Fast OOM detection
                if detect_oom(line):
                    if verbose:
                        print_warning(f"OOM detected: {line.rstrip()}")
                    process.terminate()
                    try:
                        process.wait(timeout=2)
                    except subprocess.TimeoutExpired:
                        process.kill()
                    return False, time.time() - start_time

                # Check for startup success
                if "Uvicorn running" in line or "Application startup complete" in line:
                    server_ready = True
                    break

            except Exception as e:
                if verbose:
                    print_warning(f"Error reading output: {e}")
                break

        if not server_ready:
            # Timed out or failed to start
            process.terminate()
            try:
                process.wait(timeout=2)
            except subprocess.TimeoutExpired:
                process.kill()
            return False, time.time() - start_time

        # Server is ready; test inference
        success = test_inference(test_port, verbose=verbose)

        # Cleanup
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait(timeout=2)

        return success, time.time() - start_time

    except KeyboardInterrupt:
        # User cancelled
        process.terminate()
        try:
            process.wait(timeout=2)
        except subprocess.TimeoutExpired:
            process.kill()
        raise
    except Exception as e:
        if verbose:
            print_error(f"Test failed with exception: {e}")
        try:
            process.terminate()
            process.wait(timeout=2)
        except Exception:
            try:
                process.kill()
            except Exception:
                pass
        return False, time.time() - start_time


def test_inference(port: int, verbose: bool = False) -> bool:
    """
    Test whether the server can handle a simple inference request.

    Args:
        port: Server port
        verbose: Whether to show detailed logs

    Returns:
        True if inference succeeds, False otherwise
    """
    try:
        # Wait a bit for the server to be fully ready
        time.sleep(2)

        # Try to import the OpenAI client
        try:
            from openai import OpenAI
        except ImportError:
            if verbose:
                print_warning("OpenAI package not available, skipping inference test")
            return True  # Assume success if we can't test

        client = OpenAI(
            base_url=f"http://127.0.0.1:{port}/v1",
            api_key="test",
        )

        # Send a simple test request
        response = client.chat.completions.create(
            model="test",
            messages=[{"role": "user", "content": "Hi"}],
            max_tokens=1,
            temperature=0,
            timeout=10,
        )

        # Check if we got a valid response
        success = response.choices and len(response.choices) > 0 and response.choices[0].message.content is not None

        if verbose:
            if success:
                print_info(f"Inference test passed: {response.choices[0].message.content}")
            else:
                print_warning("Inference test failed: no valid response")

        return success

    except Exception as e:
        if verbose:
            print_warning(f"Inference test failed: {e}")
        return False


def find_max_gpu_experts(
    model_path: Path,
    config: dict,
    verbose: bool = False,
) -> int:
    """
    Binary search for the maximum viable num_gpu_experts.

    Args:
        model_path: Path to the model
        config: Configuration dict
        verbose: Whether to show detailed logs

    Returns:
        Maximum number of GPU experts that works
    """
    # Get the number of experts from the model config
    try:
        num_experts = get_num_experts(model_path)
    except ValueError as e:
        print_error(str(e))
        raise

    console.print()
    console.print(f"Binary search range: [0, {num_experts}]")
    console.print()

    left, right = 0, num_experts
    result = 0
    iteration = 0
    total_iterations = math.ceil(math.log2(num_experts + 1))

    while left <= right:
        iteration += 1
        mid = (left + right) // 2

        console.print(f"[{iteration}/{total_iterations}] Testing gpu-experts={mid}... ", end="")

        success, elapsed = test_config(mid, model_path, config, verbose=verbose)

        if success:
            console.print(f"[green]✓ OK[/green] ({elapsed:.1f}s)")
            result = mid
            left = mid + 1
        else:
            console.print(f"[red]✗ FAILED[/red] ({elapsed:.1f}s)")
            right = mid - 1

    return result


def run_tuna(
    model_path: Path,
    tensor_parallel_size: int,
    max_total_tokens: int,
    kt_method: str,
    verbose: bool = False,
    **kwargs,
) -> int:
    """
    Run tuna auto-tuning to find the optimal num_gpu_experts.

    Args:
        model_path: Path to the model
        tensor_parallel_size: Tensor parallel size
        max_total_tokens: Maximum total tokens
        kt_method: KT quantization method
        verbose: Whether to show detailed logs
        **kwargs: Additional configuration parameters

    Returns:
        Optimal num_gpu_experts value

    Raises:
        ValueError: If tuning fails completely
    """
    # Prepare configuration
    config = {
        "tensor_parallel_size": tensor_parallel_size,
        "max_total_tokens": max_total_tokens,
        "kt_method": kt_method,
        **kwargs,
    }

    # Run binary search
    try:
        result = find_max_gpu_experts(model_path, config, verbose=verbose)
    except KeyboardInterrupt:
        console.print()
        print_warning("Tuning cancelled by user")
        raise

    console.print()

    # Check whether even 0 works
    if result == 0:
        console.print("[yellow]Testing if gpu-experts=0 is viable...[/yellow]")
        success, _ = test_config(0, model_path, config, verbose=verbose)

        if not success:
            # Even 0 doesn't work
            console.print()
            print_error("Failed to start server even with all experts on CPU (gpu-experts=0)")
            console.print()
            console.print("[bold]Possible reasons:[/bold]")
            console.print("  • Insufficient GPU memory for base model layers")
            console.print("  • max-total-tokens is too large for available VRAM")
            console.print("  • Tensor parallel configuration issue")
            console.print()
            console.print("[bold]Suggestions:[/bold]")
            console.print(f"  • Reduce --max-total-tokens (current: {max_total_tokens})")
            console.print(f"  • Reduce --tensor-parallel-size (current: {tensor_parallel_size})")
            console.print("  • Use more GPUs or GPUs with more VRAM")
            console.print("  • Try a smaller model")
            console.print()
            raise ValueError("Minimum GPU memory requirements not met")
        else:
            # 0 works, but nothing more
            console.print()
            print_warning("All experts will run on CPU (gpu-experts=0). Performance will be limited by CPU speed.")

    return result
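Because every probe launches a real server, the binary search keeps the number of launches logarithmic: a model with 256 routed experts needs at most ceil(log2(257)) = 9 tests. A hedged usage sketch (paths and sizes are placeholders):

# Illustrative invocation; the path and parameters are placeholders.
from pathlib import Path

best = run_tuna(
    model_path=Path("/models/DeepSeek-V3"),
    tensor_parallel_size=2,
    max_total_tokens=8192,
    kt_method="AMXINT4",
    verbose=False,
)
print(f"Max viable --kt-num-gpu-experts: {best}")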
kt-kernel/python/cli/utils/user_model_registry.py (new file, 302 lines)
@@ -0,0 +1,302 @@
|
||||
"""
|
||||
User Model Registry
|
||||
|
||||
Manages user-registered models in ~/.ktransformers/user_models.yaml
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any
|
||||
import yaml
|
||||
|
||||
|
||||
# Constants
|
||||
USER_MODELS_FILE = Path.home() / ".ktransformers" / "user_models.yaml"
|
||||
REGISTRY_VERSION = "1.0"
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserModel:
|
||||
"""Represents a user-registered model"""
|
||||
|
||||
name: str # User-editable name (default: folder name)
|
||||
path: str # Absolute path to model directory
|
||||
format: str # "safetensors" | "gguf"
|
||||
id: Optional[str] = None # Unique UUID for this model (auto-generated if None)
|
||||
repo_type: Optional[str] = None # "huggingface" | "modelscope" | None
|
||||
repo_id: Optional[str] = None # e.g., "deepseek-ai/DeepSeek-V3"
|
||||
sha256_status: str = "not_checked" # "not_checked" | "checking" | "passed" | "failed" | "no_repo"
|
||||
gpu_model_ids: Optional[List[str]] = None # For llamafile/AMX: list of GPU model UUIDs to run with
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
last_verified: Optional[str] = None # ISO format datetime
|
||||
# MoE information (cached from analyze_moe_model)
|
||||
is_moe: Optional[bool] = None # True if MoE model, False if non-MoE, None if not analyzed
|
||||
moe_num_experts: Optional[int] = None # Total number of experts (for MoE models)
|
||||
moe_num_experts_per_tok: Optional[int] = None # Number of active experts per token (for MoE models)
|
||||
# AMX quantization metadata (for format == "amx")
|
||||
amx_source_model: Optional[str] = None # Name of the source MoE model that was quantized
|
||||
amx_quant_method: Optional[str] = None # "int4" | "int8"
|
||||
amx_numa_nodes: Optional[int] = None # Number of NUMA nodes used for quantization
|
||||
|
||||
def __post_init__(self):
|
||||
"""Ensure ID is set after initialization"""
|
||||
if self.id is None:
|
||||
import uuid
|
||||
|
||||
self.id = str(uuid.uuid4())
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for YAML serialization"""
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "UserModel":
|
||||
"""Create from dictionary loaded from YAML"""
|
||||
return cls(**data)
|
||||
|
||||
def path_exists(self) -> bool:
|
||||
"""Check if model path still exists"""
|
||||
return Path(self.path).exists()
|
||||
|
||||
|
||||
class UserModelRegistry:
|
||||
"""Manages the user model registry"""
|
||||
|
||||
def __init__(self, registry_file: Optional[Path] = None):
|
||||
"""
|
||||
Initialize the registry
|
||||
|
||||
Args:
|
||||
registry_file: Path to the registry YAML file (default: USER_MODELS_FILE)
|
||||
"""
|
||||
self.registry_file = registry_file or USER_MODELS_FILE
|
||||
self.models: List[UserModel] = []
|
||||
self.version = REGISTRY_VERSION
|
||||
|
||||
# Ensure directory exists
|
||||
self.registry_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load existing registry
|
||||
self.load()
|
||||
|
||||
def load(self) -> None:
|
||||
"""Load models from YAML file"""
|
||||
if not self.registry_file.exists():
|
||||
# Initialize empty registry
|
||||
self.models = []
|
||||
self.save() # Create the file
|
||||
return
|
||||
|
||||
try:
|
||||
with open(self.registry_file, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
if not data:
|
||||
self.models = []
|
||||
return
|
||||
|
||||
# Load version
|
||||
self.version = data.get("version", REGISTRY_VERSION)
|
||||
|
||||
# Load models
|
||||
models_data = data.get("models", [])
|
||||
self.models = [UserModel.from_dict(m) for m in models_data]
|
||||
|
||||
# Migrate: ensure all models have UUIDs (for backward compatibility)
|
||||
needs_save = False
|
||||
for model in self.models:
|
||||
if model.id is None:
|
||||
import uuid
|
||||
|
||||
model.id = str(uuid.uuid4())
|
||||
needs_save = True
|
||||
|
||||
if needs_save:
|
||||
self.save()
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load user model registry: {e}")
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save models to YAML file"""
|
||||
data = {"version": self.version, "models": [m.to_dict() for m in self.models]}
|
||||
|
||||
try:
|
||||
with open(self.registry_file, "w", encoding="utf-8") as f:
|
||||
yaml.safe_dump(data, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to save user model registry: {e}")
|
||||
|
||||
    def add_model(self, model: UserModel) -> None:
        """
        Add a model to the registry

        Args:
            model: UserModel instance to add

        Raises:
            ValueError: If a model with the same name already exists
        """
        if self.check_name_conflict(model.name):
            raise ValueError(f"Model with name '{model.name}' already exists")

        self.models.append(model)
        self.save()

    def remove_model(self, name: str) -> bool:
        """
        Remove a model from the registry

        Args:
            name: Name of the model to remove

        Returns:
            True if model was removed, False if not found
        """
        original_count = len(self.models)
        self.models = [m for m in self.models if m.name != name]

        if len(self.models) < original_count:
            self.save()
            return True
        return False

    def update_model(self, name: str, updates: Dict[str, Any]) -> bool:
        """
        Update a model's attributes

        Args:
            name: Name of the model to update
            updates: Dictionary of attributes to update

        Returns:
            True if model was updated, False if not found
        """
        model = self.get_model(name)
        if not model:
            return False

        # Update attributes
        for key, value in updates.items():
            if hasattr(model, key):
                setattr(model, key, value)

        self.save()
        return True

    def get_model(self, name: str) -> Optional[UserModel]:
        """
        Get a model by name

        Args:
            name: Name of the model

        Returns:
            UserModel instance or None if not found
        """
        for model in self.models:
            if model.name == name:
                return model
        return None

    def get_model_by_id(self, model_id: str) -> Optional[UserModel]:
        """
        Get a model by its unique ID

        Args:
            model_id: UUID of the model

        Returns:
            UserModel instance or None if not found
        """
        for model in self.models:
            if model.id == model_id:
                return model
        return None

    def list_models(self) -> List[UserModel]:
        """
        List all models

        Returns:
            List of all UserModel instances
        """
        return self.models.copy()

    def find_by_path(self, path: str) -> Optional[UserModel]:
        """
        Find a model by its path

        Args:
            path: Model directory path

        Returns:
            UserModel instance or None if not found
        """
        # Normalize paths for comparison
        search_path = str(Path(path).resolve())

        for model in self.models:
            model_path = str(Path(model.path).resolve())
            if model_path == search_path:
                return model
        return None

    def check_name_conflict(self, name: str, exclude_name: Optional[str] = None) -> bool:
        """
        Check if a name conflicts with existing models

        Args:
            name: Name to check
            exclude_name: Optional name to exclude from check (for rename operations)

        Returns:
            True if conflict exists, False otherwise
        """
        for model in self.models:
            if model.name == name and model.name != exclude_name:
                return True
        return False

    def refresh_status(self) -> Dict[str, List[str]]:
        """
        Check all models and identify missing ones

        Returns:
            Dictionary with 'valid' and 'missing' lists of model names
        """
        valid = []
        missing = []

        for model in self.models:
            if model.path_exists():
                valid.append(model.name)
            else:
                missing.append(model.name)

        return {"valid": valid, "missing": missing}

    def get_model_count(self) -> int:
        """Get total number of registered models"""
        return len(self.models)

    def suggest_name(self, base_name: str) -> str:
        """
        Suggest a unique name based on base_name

        Args:
            base_name: Base name to derive from

        Returns:
            A unique name (may have suffix like -2, -3 etc.)
        """
        if not self.check_name_conflict(base_name):
            return base_name

        counter = 2
        while True:
            candidate = f"{base_name}-{counter}"
            if not self.check_name_conflict(candidate):
                return candidate
            counter += 1
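

# --- Illustrative usage sketch (added for this write-up, not in the original
# file). A minimal, hedged walk through the registry API; the UserModel
# constructor arguments (name/path) are assumed from the fields the methods
# above reference. Pointing registry_file at a temp path keeps the demo from
# touching the real USER_MODELS_FILE.
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        registry = UserModelRegistry(registry_file=Path(tmp) / "user_models.yaml")

        # suggest_name returns the base name, or "-2"/"-3" variants on conflict.
        name = registry.suggest_name("deepseek-v3")

        # add_model persists immediately; __post_init__ assigns the UUID.
        registry.add_model(UserModel(name=name, path="/data/models/deepseek-v3"))

        # Lookup, update, and health check against the filesystem.
        assert registry.get_model(name) is not None
        registry.update_model(name, {"path": "/data/models/deepseek-v3-moved"})
        print(registry.refresh_status())  # e.g. {'valid': [...], 'missing': [...]}

        registry.remove_model(name)
        assert registry.get_model_count() == 0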