kt-cli enhancement (#1834)

* [feat]: redesign kt run interactive configuration with i18n support

- Redesign kt run with 8-step interactive flow (model selection, inference method, NUMA/CPU, GPU experts, KV cache, GPU/TP selection, parsers, host/port)
- Add configuration save/load system (~/.ktransformers/run_configs.yaml)
- Add i18n support for kt chat (en/zh translations)
- Add universal input validators with auto-retry and Chinese comma support
- Add port availability checker with auto-suggestion
- Add parser configuration (--tool-call-parser, --reasoning-parser)
- Remove tuna command and clean up redundant files
- Fix: variable reference bug in run.py, filter to show only MoE models
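Two of the helpers above are small enough to sketch: a port-availability checker with auto-suggestion, and input splitting that accepts the Chinese full-width comma. The function names here are illustrative, not the CLI's actual helpers:

```python
import socket


def normalize_list_input(raw: str) -> list[str]:
    """Split a comma-separated answer, accepting the Chinese full-width comma."""
    return [part.strip() for part in raw.replace("\uff0c", ",").split(",") if part.strip()]


def suggest_port(preferred: int, host: str = "127.0.0.1", tries: int = 20) -> int:
    """Return `preferred` if it is free, otherwise the next free port."""
    for port in range(preferred, preferred + tries):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                sock.bind((host, port))
                return port  # bind succeeded, so the port is available
            except OSError:
                continue  # in use, try the next one
    raise RuntimeError(f"no free port in [{preferred}, {preferred + tries})")
```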

* [feat]: unify model selection UI and enable shared experts fusion by default

- Unify kt run model selection table with kt model list display
  * Add Total size, MoE Size, Repo, and SHA256 status columns
  * Use consistent formatting and styling
  * Improve user decision-making with more information

- Enable --disable-shared-experts-fusion by default
  * Change default value from False to True
  * Users can still override with --enable-shared-experts-fusion

* [feat]: improve kt chat with performance metrics and better CJK support

- Add performance metrics display after each response
  * Total time, TTFT (Time To First Token), TPOT (Time Per Output Token)
  * Accurate input/output token counts using model tokenizer
  * Fallback to estimation if tokenizer unavailable
  * Metrics shown in dim style (not prominent)
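The three metrics reduce to simple arithmetic over timestamps captured around the streaming loop; a sketch of the computation (not the CLI's exact code):

```python
def stream_metrics(start: float, first_token: float, end: float, chunks: int) -> dict:
    """TTFT: request sent -> first token; TPOT: mean gap between later chunks."""
    ttft = first_token - start
    total = end - start
    # A single chunk has no inter-token gaps, so TPOT is undefined (0.0 here)
    tpot = (total - ttft) / (chunks - 1) if chunks > 1 else 0.0
    return {"total_ms": total * 1000, "ttft_ms": ttft * 1000, "tpot_ms": tpot * 1000}
```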

- Fix Chinese character input issues
  * Replace Prompt.ask() with console.input() for better CJK support
  * Fixes backspace deletion showing half-characters

- Suppress NumPy subnormal warnings
  * Filter "The value of the smallest subnormal" warnings
  * Cleaner CLI output on certain hardware environments
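The filter amounts to a single `warnings` call matching the message prefix (a sketch; the real code may scope it differently):

```python
import warnings


def suppress_subnormal_warnings() -> None:
    """Ignore NumPy's 'smallest subnormal' UserWarning, emitted on hardware
    that flushes subnormal floats to zero."""
    warnings.filterwarnings(
        "ignore",
        message="The value of the smallest subnormal",
    )
```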

* [fix]: correct TTFT measurement in kt chat

- Move start_time initialization before API call
- Previously start_time was set when receiving first chunk, causing TTFT ≈ 0ms
- Now correctly measures time from request sent to first token received
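The bug and the fix in miniature: if the clock starts inside the chunk loop, TTFT collapses to roughly zero; starting it before the request gives the real figure. This is an illustrative helper, not the actual chat code:

```python
import time
from typing import Optional


def measure_ttft(send_request) -> Optional[float]:
    """Return seconds from request dispatch to the first streamed chunk."""
    start = time.time()  # fix: start timing BEFORE the API call
    for _chunk in send_request():
        return time.time() - start  # first chunk observed
    return None  # stream produced nothing
```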

* [docs]: add Clawdbot integration guide - an enterprise AI assistant deployment solution for KTransformers

* [docs]: recommend Kimi K2.5 as the core model, highlighting its enterprise-grade inference capability

* [docs]: add link to the Clawdbot Feishu integration tutorial

* [feat]: improve CLI table display, model verification, and chat experience

- Add sequence number (#) column to all model tables by default
- Filter kt edit to show only MoE GPU models (exclude AMX)
- Extend kt model verify to check *.json and *.py files in addition to weights
- Fix re-verification bug where repaired files caused false failures
- Suppress tokenizer debug output in kt chat token counting
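Extending verification from weights to config and code files is essentially a wider glob (hypothetical helper name; the real `kt model verify` also checks hashes):

```python
from pathlib import Path


def files_to_verify(model_dir: str) -> list[Path]:
    """Collect weight shards plus *.json and *.py files for checksum verification."""
    root = Path(model_dir)
    patterns = ("*.safetensors", "*.json", "*.py")
    return sorted(path for pattern in patterns for path in root.glob(pattern))
```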

* [fix]: fix CPU core handling.

---------

Co-authored-by: skqliao <skqliao@gmail.com>
Author: Oql
Date: 2026-02-04 16:44:54 +08:00
Committed by: GitHub
Parent: 4f64665758
Commit: 56cbd69ac4
23 changed files with 10327 additions and 781 deletions


@@ -96,9 +96,9 @@ def chat(
kt chat -t 0.9 --max-tokens 4096 # Adjust generation parameters
"""
if not HAS_OPENAI:
-print_error("OpenAI Python SDK is required for chat functionality.")
+print_error(t("chat_openai_required"))
console.print()
-console.print("Install it with:")
+console.print(t("chat_install_hint"))
console.print(" pip install openai")
raise typer.Exit(1)
@@ -114,10 +114,10 @@ def chat(
console.print()
console.print(
Panel.fit(
-f"[bold cyan]KTransformers Chat[/bold cyan]\n\n"
-f"Server: [yellow]{final_host}:{final_port}[/yellow]\n"
-f"Temperature: [cyan]{temperature}[/cyan] | Max tokens: [cyan]{max_tokens}[/cyan]\n\n"
-f"[dim]Type '/help' for commands, '/quit' to exit[/dim]",
+f"[bold cyan]{t('chat_title')}[/bold cyan]\n\n"
+f"{t('chat_server')}: [yellow]{final_host}:{final_port}[/yellow]\n"
+f"{t('chat_temperature')}: [cyan]{temperature}[/cyan] | {t('chat_max_tokens')}: [cyan]{max_tokens}[/cyan]\n\n"
+f"[dim]{t('chat_help_hint')}[/dim]",
border_style="cyan",
)
)
@@ -152,31 +152,44 @@ def chat(
)
# Test connection
-print_info("Connecting to server...")
+print_info(t("chat_connecting"))
models = client.models.list()
available_models = [m.id for m in models.data]
if not available_models:
-print_error("No models available on server")
+print_error(t("chat_no_models"))
raise typer.Exit(1)
# Select model
if model:
if model not in available_models:
-print_warning(f"Model '{model}' not found. Available models: {', '.join(available_models)}")
+print_warning(t("chat_model_not_found", model=model, available=", ".join(available_models)))
selected_model = available_models[0]
else:
selected_model = model
else:
selected_model = available_models[0]
-print_success(f"Connected to model: {selected_model}")
+print_success(t("chat_connected", model=selected_model))
console.print()
# Load tokenizer for accurate token counting
tokenizer = None
try:
from transformers import AutoTokenizer
# selected_model is the model path
tokenizer = AutoTokenizer.from_pretrained(selected_model, trust_remote_code=True)
console.print(f"[dim]Loaded tokenizer from {selected_model}[/dim]")
console.print()
except Exception:
console.print("[dim yellow]Warning: Could not load tokenizer, token counts will be estimated[/dim yellow]")
console.print()
except Exception as e:
-print_error(f"Failed to connect to server: {e}")
+print_error(t("chat_connect_failed", error=str(e)))
console.print()
-console.print("Make sure the model server is running:")
+console.print(t("chat_server_not_running"))
console.print(" kt run <model>")
raise typer.Exit(1)
@@ -201,12 +214,12 @@ def chat(
# Main chat loop
try:
while True:
-# Get user input
+# Get user input - use console.input() for better CJK character support
try:
-user_input = Prompt.ask("[bold green]You[/bold green]")
+user_input = console.input(f"[bold green]{t('chat_user_prompt')}[/bold green]: ")
except (EOFError, KeyboardInterrupt):
console.print()
-print_info("Goodbye!")
+print_info(t("chat_goodbye"))
break
if not user_input.strip():
@@ -224,15 +237,19 @@ def chat(
# Generate response
console.print()
-console.print("[bold cyan]Assistant[/bold cyan]")
+console.print(f"[bold cyan]{t('chat_assistant_prompt')}[/bold cyan]")
try:
if stream:
# Streaming response
-response_content = _stream_response(client, selected_model, messages, temperature, max_tokens)
+response_content = _stream_response(
+    client, selected_model, messages, temperature, max_tokens, tokenizer
+)
else:
# Non-streaming response
-response_content = _generate_response(client, selected_model, messages, temperature, max_tokens)
+response_content = _generate_response(
+    client, selected_model, messages, temperature, max_tokens, tokenizer
+)
# Add assistant response to history
messages.append({"role": "assistant", "content": response_content})
@@ -240,7 +257,7 @@ def chat(
console.print()
except Exception as e:
-print_error(f"Error generating response: {e}")
+print_error(t("chat_generation_error", error=str(e)))
# Remove the user message that caused the error
messages.pop()
continue
@@ -252,12 +269,12 @@ def chat(
except KeyboardInterrupt:
console.print()
console.print()
-print_info("Chat interrupted. Goodbye!")
+print_info(t("chat_interrupted"))
# Final history save
if save_history and messages:
_save_history(history_file, messages, selected_model)
-console.print(f"[dim]History saved to: {history_file}[/dim]")
+console.print(f"[dim]{t('chat_history_saved', path=str(history_file))}[/dim]")
console.print()
@@ -267,12 +284,22 @@ def _stream_response(
messages: list,
temperature: float,
max_tokens: int,
tokenizer=None,
) -> str:
"""Generate streaming response and display in real-time."""
import time
response_content = ""
reasoning_content = ""
# Performance tracking
first_token_time = None
chunk_count = 0
try:
# Start timing before sending request
start_time = time.time()
stream = client.chat.completions.create(
model=model,
messages=messages,
@@ -282,33 +309,120 @@ def _stream_response(
)
for chunk in stream:
delta = chunk.choices[0].delta
reasoning_delta = getattr(delta, "reasoning_content", None)
if reasoning_delta:
reasoning_content += reasoning_delta
console.print(reasoning_delta, end="", style="dim")
if delta.content:
content = delta.content
response_content += content
console.print(content, end="")
delta = chunk.choices[0].delta if chunk.choices else None
if delta:
reasoning_delta = getattr(delta, "reasoning_content", None)
if reasoning_delta:
if first_token_time is None:
first_token_time = time.time()
reasoning_content += reasoning_delta
console.print(reasoning_delta, end="", style="dim")
chunk_count += 1
if delta.content:
if first_token_time is None:
first_token_time = time.time()
content = delta.content
response_content += content
console.print(content, end="")
chunk_count += 1
console.print() # Newline after streaming
# Display performance metrics
end_time = time.time()
if first_token_time and chunk_count > 0:
ttft = first_token_time - start_time
total_time = end_time - start_time
# Calculate TPOT based on chunks
if chunk_count > 1:
generation_time = total_time - ttft
tpot = generation_time / (chunk_count - 1)
else:
tpot = 0
# Calculate accurate token counts using tokenizer
if tokenizer:
input_tokens = _count_tokens_with_tokenizer(messages, tokenizer)
output_tokens = _count_tokens_with_tokenizer(
[{"role": "assistant", "content": response_content}], tokenizer
)
token_prefix = ""
else:
# Fallback to estimation
input_tokens = _estimate_tokens(messages)
output_tokens = _estimate_tokens([{"role": "assistant", "content": response_content}])
token_prefix = "~"
# Build metrics display
metrics = f"[dim]Total: {total_time*1000:.0f}ms | TTFT: {ttft*1000:.0f}ms"
if tpot > 0:
metrics += f" | TPOT: {tpot*1000:.1f}ms"
metrics += f" | In: {token_prefix}{input_tokens} | Out: {token_prefix}{output_tokens}"
metrics += "[/dim]"
console.print(metrics)
except Exception as e:
raise Exception(f"Streaming error: {e}")
return response_content
def _count_tokens_with_tokenizer(messages: list, tokenizer) -> int:
"""Count tokens accurately using the model's tokenizer."""
try:
# Concatenate all message content
text = ""
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
# Simple format: role + content
text += f"{role}: {content}\n"
# Encode and count tokens - suppress any debug output from custom tokenizers
import os
import sys
from contextlib import redirect_stdout, redirect_stderr
with open(os.devnull, "w") as devnull:
with redirect_stdout(devnull), redirect_stderr(devnull):
tokens = tokenizer.encode(text, add_special_tokens=True)
return len(tokens)
except Exception:
# Fallback to estimation if tokenizer fails
return _estimate_tokens(messages)
def _estimate_tokens(messages: list) -> int:
"""Estimate token count for messages (rough approximation)."""
total_chars = 0
for msg in messages:
content = msg.get("content", "")
total_chars += len(content)
# Rough estimation:
# - English: ~4 chars per token
# - Chinese: ~1.5 chars per token
# Use 2.5 as average
return max(1, int(total_chars / 2.5))
def _generate_response(
client: "OpenAI",
model: str,
messages: list,
temperature: float,
max_tokens: int,
tokenizer=None,
) -> str:
"""Generate non-streaming response."""
import time
try:
start_time = time.time()
response = client.chat.completions.create(
model=model,
messages=messages,
@@ -317,12 +431,36 @@ def _generate_response(
stream=False,
)
end_time = time.time()
total_time = end_time - start_time
content = response.choices[0].message.content
# Display as markdown
md = Markdown(content)
console.print(md)
# Calculate accurate token counts using tokenizer
if tokenizer:
input_tokens = _count_tokens_with_tokenizer(messages, tokenizer)
output_tokens = _count_tokens_with_tokenizer([{"role": "assistant", "content": content}], tokenizer)
token_prefix = ""
else:
# Fallback to API usage or estimation
input_tokens = response.usage.prompt_tokens if response.usage else _estimate_tokens(messages)
output_tokens = (
response.usage.completion_tokens
if response.usage
else _estimate_tokens([{"role": "assistant", "content": content}])
)
token_prefix = "" if response.usage else "~"
# Display performance metrics
console.print(
f"[dim]Time: {total_time*1000:.0f}ms | "
f"In: {token_prefix}{input_tokens} | Out: {token_prefix}{output_tokens}[/dim]"
)
return content
except Exception as e:
@@ -335,20 +473,14 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens
if cmd in ["/quit", "/exit", "/q"]:
console.print()
-print_info("Goodbye!")
+print_info(t("chat_goodbye"))
return False
elif cmd in ["/help", "/h"]:
console.print()
console.print(
Panel(
-"[bold]Available Commands:[/bold]\n\n"
-"/help, /h - Show this help message\n"
-"/quit, /exit, /q - Exit chat\n"
-"/clear, /c - Clear conversation history\n"
-"/history, /hist - Show conversation history\n"
-"/info, /i - Show current settings\n"
-"/retry, /r - Regenerate last response",
+f"[bold]{t('chat_help_title')}[/bold]\n\n{t('chat_help_content')}",
title="Help",
border_style="cyan",
)
@@ -359,19 +491,19 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens
elif cmd in ["/clear", "/c"]:
messages.clear()
console.print()
-print_success("Conversation history cleared")
+print_success(t("chat_history_cleared"))
console.print()
return True
elif cmd in ["/history", "/hist"]:
console.print()
if not messages:
-print_info("No conversation history")
+print_info(t("chat_no_history"))
else:
console.print(
Panel(
_format_history(messages),
-title=f"History ({len(messages)} messages)",
+title=t("chat_history_title", count=len(messages)),
border_style="cyan",
)
)
@@ -382,10 +514,7 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens
console.print()
console.print(
Panel(
-f"[bold]Current Settings:[/bold]\n\n"
-f"Temperature: [cyan]{temperature}[/cyan]\n"
-f"Max tokens: [cyan]{max_tokens}[/cyan]\n"
-f"Messages: [cyan]{len(messages)}[/cyan]",
+f"[bold]{t('chat_info_title')}[/bold]\n\n{t('chat_info_content', temperature=temperature, max_tokens=max_tokens, messages=len(messages))}",
title="Info",
border_style="cyan",
)
@@ -397,16 +526,16 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens
if len(messages) >= 2 and messages[-1]["role"] == "assistant":
# Remove last assistant response
messages.pop()
-print_info("Retrying last response...")
+print_info(t("chat_retrying"))
console.print()
else:
-print_warning("No previous response to retry")
+print_warning(t("chat_no_retry"))
console.print()
return True
else:
-print_warning(f"Unknown command: {command}")
-console.print("[dim]Type /help for available commands[/dim]")
+print_warning(t("chat_unknown_command", command=command))
+console.print(f"[dim]{t('chat_unknown_hint')}[/dim]")
console.print()
return True

File diff suppressed because it is too large


@@ -35,12 +35,12 @@ class QuantMethod(str, Enum):
def quant(
-model: str = typer.Argument(
-    ...,
+model: Optional[str] = typer.Argument(
+    None,
help="Model name or path to quantize",
),
-method: QuantMethod = typer.Option(
-    QuantMethod.INT4,
+method: Optional[QuantMethod] = typer.Option(
+    None,
"--method",
"-m",
help="Quantization method",
@@ -51,8 +51,8 @@ def quant(
"-o",
help="Output path for quantized weights",
),
-input_type: str = typer.Option(
-    "fp8",
+input_type: Optional[str] = typer.Option(
+    None,
"--input-type",
"-i",
help="Input weight type (fp8, fp16, bf16)",
@@ -72,6 +72,11 @@ def quant(
"--no-merge",
help="Don't merge safetensor files",
),
gpu: bool = typer.Option(
False,
"--gpu",
help="Use GPU for conversion (faster)",
),
yes: bool = typer.Option(
False,
"--yes",
@@ -79,54 +84,231 @@ def quant(
help="Skip confirmation prompts",
),
) -> None:
"""Quantize model weights for CPU inference."""
settings = get_settings()
console.print()
"""Quantize model weights for CPU inference.
# Resolve input path
input_path = _resolve_input_path(model, settings)
if input_path is None:
print_error(t("quant_input_not_found", path=model))
If no model is specified, interactive mode will be activated.
"""
settings = get_settings()
# Check if we should use interactive mode
# Interactive mode triggers when: no model, or missing critical parameters
needs_interactive = model is None or method is None or cpu_threads is None or numa_nodes is None
is_interactive = False
if needs_interactive and sys.stdin.isatty():
# Use interactive configuration (includes verification in Step 1.5)
from kt_kernel.cli.utils.quant_interactive import interactive_quant_config
console.print()
console.print(f"[bold cyan]═══ {t('quant_interactive_title')} ═══[/bold cyan]")
console.print()
console.print(f"[yellow]{t('quant_new_model_notice')}[/yellow]")
console.print()
config = interactive_quant_config()
if config is None:
# User cancelled
raise typer.Exit(0)
# Extract configuration
model_obj = config["model"]
model = model_obj.id
input_path = Path(model_obj.path)
method = QuantMethod(config["method"])
input_type = config["input_type"]
cpu_threads = config["cpu_threads"]
numa_nodes = config["numa_nodes"]
output = config["output_path"]
gpu = config["use_gpu"]
is_interactive = True
console.print()
print_success(t("quant_config_complete"))
console.print()
else:
# Non-interactive mode - require model parameter
if model is None:
print_error("Model argument is required in non-interactive mode")
console.print()
console.print("Usage: kt quant <model>")
console.print(" Or: kt quant (for interactive mode)")
raise typer.Exit(1)
# Set defaults for optional parameters
method = method or QuantMethod.INT4
input_type = input_type or "fp8"
console.print()
# Resolve input path
input_path = _resolve_input_path(model, settings)
if input_path is None:
print_error(t("quant_input_not_found", path=model))
raise typer.Exit(1)
# Pre-quantization verification (only in non-interactive mode)
# Interactive mode already did verification in interactive_quant_config()
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
from kt_kernel.cli.utils.model_verifier import pre_operation_verification
user_registry = UserModelRegistry()
user_model_obj = user_registry.find_by_path(str(input_path))
if user_model_obj and user_model_obj.format == "safetensors":
pre_operation_verification(user_model_obj, user_registry, operation_name="quantizing")
# Get user model info for both modes (needed later for registering quantized model)
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
user_registry = UserModelRegistry()
user_model_obj = user_registry.find_by_path(str(input_path))
# Validate that it's a MoE model (not AMX or GGUF)
from kt_kernel.cli.commands.model import is_amx_weights
# Check if it's AMX (already quantized)
is_amx, _ = is_amx_weights(str(input_path))
if is_amx:
print_error("Cannot quantize AMX models (already quantized)")
console.print()
console.print(f" The model at {input_path} is already in AMX format.")
raise typer.Exit(1)
print_info(t("quant_input_path", path=str(input_path)))
# Check if it's a MoE model
from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model
# Resolve output path
if output is None:
output = input_path.parent / f"{input_path.name}-{method.value.upper()}"
print_info(t("quant_output_path", path=str(output)))
print_info(t("quant_method", method=method.value.upper()))
# Detect CPU configuration
cpu = detect_cpu_info()
final_cpu_threads = cpu_threads or cpu.cores
final_numa_nodes = numa_nodes or cpu.numa_nodes
print_info(f"CPU threads: {final_cpu_threads}")
print_info(f"NUMA nodes: {final_numa_nodes}")
# Check if output exists
if output.exists():
print_warning(f"Output path already exists: {output}")
moe_result = None # Store for later use when registering quantized model
try:
moe_result = analyze_moe_model(str(input_path), use_cache=True)
if not moe_result or not moe_result.get("is_moe"):
print_error("Only MoE models can be quantized to AMX format")
console.print()
console.print(f" The model at {input_path} is not a MoE model.")
console.print(" AMX quantization is designed for MoE models (e.g., DeepSeek-V3).")
raise typer.Exit(1)
except Exception as e:
print_warning(f"Could not detect MoE information: {e}")
console.print()
if not yes:
if not confirm("Overwrite?", default=False):
if not confirm("Continue quantization anyway?", default=False):
raise typer.Exit(1)
# Detect CPU configuration and resolve output path (only needed in non-interactive mode)
if not is_interactive:
print_info(t("quant_input_path", path=str(input_path)))
# Detect CPU configuration (needed for output path)
cpu = detect_cpu_info()
final_cpu_threads = cpu_threads or cpu.cores
final_numa_nodes = numa_nodes or cpu.numa_nodes
# Resolve output path
if output is None:
# Priority: paths.weights > paths.models[0] > model's parent directory
weights_dir = settings.weights_dir
if weights_dir and weights_dir.exists():
# Use configured weights directory (highest priority)
output = weights_dir / f"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}"
else:
# Use first model storage path
model_paths = settings.get_model_paths()
if model_paths and model_paths[0].exists():
output = model_paths[0] / f"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}"
else:
# Fallback to model's parent directory
output = input_path.parent / f"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}"
print_info(t("quant_output_path", path=str(output)))
print_info(t("quant_method", method=method.value.upper()))
print_info(t("quant_cpu_threads", threads=final_cpu_threads))
print_info(t("quant_numa_nodes", nodes=final_numa_nodes))
# Calculate space requirements
console.print()
console.print(f"[bold cyan]{t('quant_disk_analysis')}[/bold cyan]")
console.print()
# Calculate source model size
try:
total_bytes = sum(f.stat().st_size for f in input_path.glob("*.safetensors") if f.is_file())
source_size_gb = total_bytes / (1024**3)
except Exception:
source_size_gb = 0.0
# Estimate quantized size
input_bits = {"fp8": 8, "fp16": 16, "bf16": 16}
quant_bits = {"int4": 4, "int8": 8}
input_bit = input_bits.get(input_type, 16)
quant_bit = quant_bits.get(method.value, 4)
ratio = quant_bit / input_bit
estimated_size_gb = source_size_gb * ratio
# Check available space
import shutil
try:
check_path = output.parent if not output.exists() else output
while not check_path.exists() and check_path != check_path.parent:
check_path = check_path.parent
stat = shutil.disk_usage(check_path)
available_gb = stat.free / (1024**3)
except Exception:
available_gb = 0.0
is_sufficient = available_gb >= (estimated_size_gb * 1.2)
console.print(f" {t('quant_source_size'):<26} {source_size_gb:.2f} GB")
console.print(f" {t('quant_estimated_size'):<26} {estimated_size_gb:.2f} GB")
console.print(f" {t('quant_available_space'):<26} {available_gb:.2f} GB")
console.print()
if not is_sufficient:
required_with_buffer = estimated_size_gb * 1.2
print_warning(t("quant_insufficient_space"))
console.print()
console.print(f" {t('quant_required_space'):<26} {required_with_buffer:.2f} GB")
console.print(f" {t('quant_available_space'):<26} {available_gb:.2f} GB")
console.print(f" {t('quant_shortage'):<26} {required_with_buffer - available_gb:.2f} GB")
console.print()
console.print(f" {t('quant_may_fail')}")
console.print()
if not yes:
if not confirm(t("quant_continue_anyway"), default=False):
raise typer.Abort()
console.print()
# Check if output exists and generate unique name
if output.exists():
print_warning(t("quant_output_exists", path=str(output)))
console.print()
# Generate unique name by adding suffix
original_name = output.name
parent_dir = output.parent
counter = 2
while output.exists():
new_name = f"{original_name}-{counter}"
output = parent_dir / new_name
counter += 1
print_success(t("quant_using_unique", path=str(output)))
console.print()
# Confirm (only show if not using --yes flag)
if not yes:
console.print()
print_warning(t("quant_time_warning"))
console.print()
if not confirm(t("prompt_continue")):
raise typer.Abort()
# Confirm
if not yes:
console.print()
console.print("[bold]Quantization Settings:[/bold]")
console.print(f" Input: {input_path}")
console.print(f" Output: {output}")
console.print(f" Method: {method.value.upper()}")
console.print(f" Input type: {input_type}")
console.print()
print_warning("Quantization may take 30-60 minutes depending on model size.")
console.print()
if not confirm(t("prompt_continue")):
raise typer.Abort()
else:
# Interactive mode: cpu_threads and numa_nodes already set
final_cpu_threads = cpu_threads
final_numa_nodes = numa_nodes
# Find conversion script
kt_kernel_path = _find_kt_kernel_path()
@@ -141,37 +323,145 @@ def quant(
# Build command
cmd = [
-sys.executable, str(script_path),
-"--input-path", str(input_path),
-"--input-type", input_type,
-"--output", str(output),
-"--quant-method", method.value,
-"--cpuinfer-threads", str(final_cpu_threads),
-"--threadpool-count", str(final_numa_nodes),
+sys.executable,
+str(script_path),
+"--input-path",
+str(input_path),
+"--input-type",
+input_type,
+"--output",
+str(output),
+"--quant-method",
+method.value,
+"--cpuinfer-threads",
+str(final_cpu_threads),
+"--threadpool-count",
+str(final_numa_nodes),
]
if no_merge:
cmd.append("--no-merge-safetensor")
if gpu:
cmd.append("--gpu")
# Run quantization
console.print()
print_step(t("quant_starting"))
console.print()
console.print(f"[dim]$ {' '.join(cmd)}[/dim]")
console.print()
console.print("[dim]" + "=" * 80 + "[/dim]")
console.print()
try:
-process = subprocess.run(cmd)
+# Run with real-time stdout/stderr output
+import os
+import time
+
+env = os.environ.copy()
+env["PYTHONUNBUFFERED"] = "1"  # Disable Python output buffering
+
+# Record start time
+start_time = time.time()
+
+process = subprocess.run(
+    cmd,
+    stdout=None,  # Inherit parent's stdout (real-time output)
+    stderr=None,  # Inherit parent's stderr (real-time output)
+    env=env,
+)
# Calculate elapsed time
elapsed_time = time.time() - start_time
hours = int(elapsed_time // 3600)
minutes = int((elapsed_time % 3600) // 60)
seconds = int(elapsed_time % 60)
console.print()
console.print("[dim]" + "=" * 80 + "[/dim]")
console.print()
if process.returncode == 0:
console.print()
print_success(t("quant_complete"))
console.print()
# Display elapsed time
if hours > 0:
time_str = f"{hours}h {minutes}m {seconds}s"
elif minutes > 0:
time_str = f"{minutes}m {seconds}s"
else:
time_str = f"{seconds}s"
console.print(f" [cyan]{t('quant_time_elapsed')} {time_str}[/cyan]")
console.print()
console.print(f" Quantized weights saved to: {output}")
console.print()
console.print(" Use with:")
console.print(f" kt run {model} --weights-path {output}")
console.print()
# Auto-register the quantized model
try:
from kt_kernel.cli.utils.user_model_registry import UserModel
# Generate model name from output path
base_name = output.name
suggested_name = user_registry.suggest_name(base_name)
# Determine MoE information and source model name
if user_model_obj:
is_moe_val = user_model_obj.is_moe
num_experts = user_model_obj.moe_num_experts
num_active = user_model_obj.moe_num_experts_per_tok
repo_type_val = user_model_obj.repo_type
repo_id_val = user_model_obj.repo_id
source_model_name = user_model_obj.name # Store source model name
elif moe_result:
is_moe_val = moe_result.get("is_moe", True)
num_experts = moe_result.get("num_experts")
num_active = moe_result.get("num_experts_per_tok")
repo_type_val = None
repo_id_val = None
source_model_name = input_path.name # Use folder name as fallback
else:
is_moe_val = None
num_experts = None
num_active = None
repo_type_val = None
repo_id_val = None
source_model_name = input_path.name # Use folder name as fallback
# Create new model entry (AMX format uses "safetensors" format, detected by is_amx_weights())
new_model = UserModel(
name=suggested_name,
path=str(output),
format="safetensors", # AMX files are safetensors format
repo_type=repo_type_val,
repo_id=repo_id_val,
sha256_status="not_checked", # AMX weights don't need verification
# Inherit MoE information from source model
is_moe=is_moe_val,
moe_num_experts=num_experts,
moe_num_experts_per_tok=num_active,
# AMX quantization metadata
amx_source_model=source_model_name,
amx_quant_method=method.value, # "int4" or "int8"
amx_numa_nodes=final_numa_nodes,
)
user_registry.add_model(new_model)
console.print()
print_success(t("quant_registered", name=suggested_name))
console.print()
console.print(f" {t('quant_view_with')} [cyan]kt model list[/cyan]")
console.print(f" {t('quant_use_with')} [cyan]kt run {suggested_name}[/cyan]")
console.print()
except Exception as e:
# Non-fatal error - quantization succeeded but registration failed
console.print()
print_warning(t("quant_register_failed", error=str(e)))
console.print()
console.print(f" {t('quant_use_with')}")
console.print(f" kt run {model} --weights-path {output}")
console.print()
else:
print_error(f"Quantization failed with exit code {process.returncode}")
raise typer.Exit(process.returncode)
@@ -221,6 +511,7 @@ def _find_kt_kernel_path() -> Optional[Path]:
"""Find the kt-kernel installation path."""
try:
import kt_kernel
return Path(kt_kernel.__file__).parent.parent
except ImportError:
pass


@@ -28,7 +28,7 @@ from kt_kernel.cli.utils.console import (
prompt_choice,
)
from kt_kernel.cli.utils.environment import detect_cpu_info, detect_gpus, detect_ram_gb
-from kt_kernel.cli.utils.model_registry import MODEL_COMPUTE_FUNCTIONS, ModelInfo, get_registry
+from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
@click.command(
@@ -120,8 +120,6 @@ def run(
# Handle disable/enable shared experts fusion flags
if enable_shared_experts_fusion:
disable_shared_experts_fusion = False
elif disable_shared_experts_fusion is None:
disable_shared_experts_fusion = None
# Convert Path objects from click
model_path_obj = Path(model_path) if model_path else None
@@ -214,266 +212,250 @@ def _run_impl(
raise typer.Exit(1)
settings = get_settings()
registry = get_registry()
user_registry = UserModelRegistry()
console.print()
# Check if we should use interactive mode
# Interactive mode triggers when:
# 1. No model specified, OR
# 2. Model specified but missing critical parameters (gpu_experts, tensor_parallel_size, etc.)
use_interactive = False
# If no model specified, show interactive selection
if model is None:
model = _interactive_model_selection(registry, settings)
if model is None:
use_interactive = True
elif (
gpu_experts is None
or tensor_parallel_size is None
or cpu_threads is None
or numa_nodes is None
or max_total_tokens is None
):
# Model specified but some parameters missing - use interactive
use_interactive = True
if use_interactive and sys.stdin.isatty():
# Use new interactive configuration flow
from kt_kernel.cli.utils.run_interactive import interactive_run_config
console.print()
console.print("[bold cyan]═══ Interactive Run Configuration ═══[/bold cyan]")
console.print()
config = interactive_run_config()
if config is None:
# User cancelled
raise typer.Exit(0)
# Step 1: Detect hardware
print_step(t("run_detecting_hardware"))
gpus = detect_gpus()
cpu = detect_cpu_info()
ram = detect_ram_gb()
# Extract configuration from new format
user_model_obj = config["model"]
model = user_model_obj.id
resolved_model_path = Path(config["model_path"])
resolved_weights_path = Path(config["weights_path"])
if gpus:
gpu_info = f"{gpus[0].name} ({gpus[0].vram_gb}GB VRAM)"
if len(gpus) > 1:
gpu_info += f" + {len(gpus) - 1} more"
print_info(t("run_gpu_info", name=gpus[0].name, vram=gpus[0].vram_gb))
# Extract parameters
gpu_experts = config["gpu_experts"]
cpu_threads = config["cpu_threads"]
numa_nodes = config["numa_nodes"]
tensor_parallel_size = config["tp_size"]
# Get kt-method and other method-specific settings
kt_method = config["kt_method"]
# KV cache settings (may be None for non-raw methods)
max_total_tokens = config.get("kv_cache", 32768)
chunked_prefill_size = config.get("chunk_prefill", 32768)
kt_gpu_prefill_threshold = config.get("gpu_prefill_threshold", 500)
# Memory settings
mem_fraction_static = config["mem_fraction_static"]
# Parser settings (optional)
tool_call_parser = config.get("tool_call_parser")
reasoning_parser = config.get("reasoning_parser")
# Server settings
host = config.get("host", "0.0.0.0")
port = config.get("port", 30000)
# Set CUDA_VISIBLE_DEVICES for selected GPUs
selected_gpus = config["selected_gpus"]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu_id) for gpu_id in selected_gpus)
console.print()
print_info(f"[green]✓[/green] Configuration complete")
console.print()
else:
    # Non-interactive mode - use the traditional flow
# Initialize variables that may have been set by interactive mode
# These will be None in non-interactive mode and will use defaults via resolve()
# Step 1: Detect hardware
print_step(t("run_detecting_hardware"))
gpus = detect_gpus()
cpu = detect_cpu_info()
ram = detect_ram_gb()
if gpus:
    gpu_info = f"{gpus[0].name} ({gpus[0].vram_gb}GB VRAM)"
    if len(gpus) > 1:
        gpu_info += f" + {len(gpus) - 1} more"
    print_info(t("run_gpu_info", name=gpus[0].name, vram=gpus[0].vram_gb))
else:
    print_warning(t("doctor_gpu_not_found"))
    gpu_info = "None"
print_info(t("run_cpu_info", name=cpu.name, cores=cpu.cores, numa=cpu.numa_nodes))
print_info(t("run_ram_info", total=int(ram)))
# Step 2: Resolve model
console.print()
print_step(t("run_checking_model"))
# If no model was specified, fall back to interactive selection
if model is None:
    model = _interactive_model_selection(user_registry, settings)
    if model is None:
        raise typer.Exit(0)
model_info = None  # legacy registry metadata; resolution now goes through user_registry
user_model_obj = None
resolved_model_path = model_path
# Check if model is a path
if Path(model).exists():
resolved_model_path = Path(model)
print_info(t("run_model_path", path=str(resolved_model_path)))
# Try to find in user registry by path
user_model_obj = user_registry.find_by_path(str(resolved_model_path))
if user_model_obj:
print_info(f"Using registered model: {user_model_obj.name}")
else:
print_warning("Using unregistered model path. Consider adding it with 'kt model add'")
else:
# Search in user registry by name
user_model_obj = user_registry.get_model(model)
if not user_model_obj:
print_error(t("run_model_not_found", name=model))
console.print()
# Show available models
all_models = user_registry.list_models()
if all_models:
console.print("Available registered models:")
for m in all_models[:5]:
console.print(f" - {m.name}")
if len(all_models) > 5:
console.print(f" ... and {len(all_models) - 5} more")
else:
console.print("No models registered yet.")
console.print()
console.print(f"Add your model with: [cyan]kt model add /path/to/model[/cyan]")
console.print(f"Or scan for models: [cyan]kt model scan[/cyan]")
raise typer.Exit(1)
# Use model path from registry
resolved_model_path = Path(user_model_obj.path)
print_info(t("run_model_path", path=str(resolved_model_path)))
# Verify the resolved path exists
if not resolved_model_path.exists():
    print_error(f"Model path does not exist: {resolved_model_path}")
    console.print()
    console.print("Run 'kt model refresh' to check all models")
    raise typer.Exit(1)
# Step 2.5: Pre-run verification (optional integrity check)
if user_model_obj and user_model_obj.format == "safetensors":
from kt_kernel.cli.utils.model_verifier import pre_operation_verification
pre_operation_verification(user_model_obj, user_registry, operation_name="running")
# Step 3: Check quantized weights (only if explicitly requested)
resolved_weights_path = None
# Only use quantized weights if explicitly specified by user
if weights_path is not None:
# User explicitly specified weights path
resolved_weights_path = weights_path
if not resolved_weights_path.exists():
print_error(t("run_weights_not_found"))
console.print(f" Path: {resolved_weights_path}")
raise typer.Exit(1)
print_info(f"Using quantized weights: {resolved_weights_path}")
elif quantize:
# User requested quantization
console.print()
print_step(t("run_quantizing"))
# TODO: Implement quantization
print_warning("Quantization not yet implemented. Please run 'kt quant' manually.")
raise typer.Exit(1)
else:
# Default: use original precision model without quantization
console.print()
print_info("Using original precision model (no quantization)")
# Step 4: Build command
# Helper to resolve parameter with fallback chain: CLI > config > default
def resolve(cli_val, config_key, default):
if cli_val is not None:
return cli_val
config_val = settings.get(config_key)
return config_val if config_val is not None else default
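For clarity, the fallback chain that `resolve()` implements can be exercised in isolation. This is a minimal sketch; `StubSettings` and the explicit `settings` parameter are illustrative stand-ins, not part of the CLI (in the real code `resolve()` closes over the surrounding `settings` object):

```python
# Standalone sketch of the CLI > config > default fallback chain.
# StubSettings is a hypothetical stand-in for the real settings object.
class StubSettings:
    def __init__(self, values):
        self._values = values

    def get(self, key, default=None):
        return self._values.get(key, default)


def resolve(cli_val, settings, config_key, default):
    # A CLI argument wins, then the config file, then the built-in default
    if cli_val is not None:
        return cli_val
    config_val = settings.get(config_key)
    return config_val if config_val is not None else default


settings = StubSettings({"inference.kt_method": "AMXINT8"})
print(resolve("FP8", settings, "inference.kt_method", "AMXINT4"))  # CLI wins
print(resolve(None, settings, "inference.kt_method", "AMXINT4"))   # config wins
print(resolve(None, settings, "server.port", 30000))               # default wins
```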
# Get defaults from model info if available
model_defaults = model_info.default_params if model_info else {}
# Server configuration
final_host = resolve(host, "server.host", "0.0.0.0")
final_port = resolve(port, "server.port", 30000)
# Determine tensor parallel size first (needed for GPU expert calculation)
# Priority: CLI > config > auto-detect from GPUs (with model constraints)
explicitly_specified = tensor_parallel_size or settings.get("inference.tensor_parallel_size")
if explicitly_specified:
    # Use the explicitly specified value
    requested_tensor_parallel_size = explicitly_specified
else:
    # Auto-detect from GPUs, honoring the model's max constraint
    detected_gpu_count = len(gpus) if gpus else 1
    if model_info and model_info.max_tensor_parallel_size is not None:
        # Limit to the model's maximum while using as many GPUs as possible
        requested_tensor_parallel_size = min(detected_gpu_count, model_info.max_tensor_parallel_size)
    else:
        requested_tensor_parallel_size = detected_gpu_count
# Clamp an explicitly specified value that exceeds the model's maximum
final_tensor_parallel_size = requested_tensor_parallel_size
if model_info and model_info.max_tensor_parallel_size is not None:
    if requested_tensor_parallel_size > model_info.max_tensor_parallel_size:
        console.print()
        print_warning(
            f"Model {model_info.name} only supports up to {model_info.max_tensor_parallel_size}-way "
            f"tensor parallelism, but {requested_tensor_parallel_size} was requested. "
            f"Reducing to {model_info.max_tensor_parallel_size}."
        )
        final_tensor_parallel_size = model_info.max_tensor_parallel_size
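The detect-then-clamp logic above boils down to taking the requested (or detected) GPU count capped by the model's limit. A standalone sketch, with illustrative function and parameter names (not the CLI's API):

```python
from typing import Optional


def pick_tp_size(requested: Optional[int], gpu_count: int, model_max: Optional[int]) -> int:
    # An explicit request wins; otherwise use all detected GPUs (at least 1)
    tp = requested if requested else max(gpu_count, 1)
    # Clamp to the model's maximum supported tensor parallelism
    if model_max is not None and tp > model_max:
        tp = model_max
    return tp


print(pick_tp_size(None, 8, 4))  # auto-detected 8 GPUs, clamped to the model max of 4
print(pick_tp_size(2, 8, 4))     # explicit request of 2 is kept
```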
# CPU/GPU configuration with smart defaults
# kt-cpuinfer: default to 80% of total logical CPU threads
total_threads = cpu.threads  # logical threads, not just physical cores
final_cpu_threads = resolve(cpu_threads, "inference.cpu_threads", int(total_threads * 0.8))
# kt-threadpool-count: default to the NUMA node count
final_numa_nodes = resolve(numa_nodes, "inference.numa_nodes", cpu.numa_nodes)
# kt-num-gpu-experts: use a model-specific computation if available and not explicitly set
if gpu_experts is not None:
    # User explicitly set it
    final_gpu_experts = gpu_experts
elif model_info and model_info.name in MODEL_COMPUTE_FUNCTIONS and gpus:
    # Use the model-specific computation function (only if GPUs were detected)
    vram_per_gpu = gpus[0].vram_gb
    compute_func = MODEL_COMPUTE_FUNCTIONS[model_info.name]
    final_gpu_experts = compute_func(final_tensor_parallel_size, vram_per_gpu)
    console.print()
    print_info(
        f"Auto-computed kt-num-gpu-experts: {final_gpu_experts} "
        f"(TP={final_tensor_parallel_size}, VRAM={vram_per_gpu}GB per GPU)"
    )
else:
    # Fall back to configured/default value
    final_gpu_experts = resolve(gpu_experts, "inference.gpu_experts", 1)
# KT-kernel options
final_kt_method = resolve(kt_method, "inference.kt_method", "AMXINT4")
final_kt_gpu_prefill_threshold = resolve(kt_gpu_prefill_threshold, "inference.kt_gpu_prefill_token_threshold", 4096)
# SGLang options
final_attention_backend = resolve(attention_backend, "inference.attention_backend", "flashinfer")
final_max_total_tokens = resolve(max_total_tokens, "inference.max_total_tokens", 40000)
final_max_running_requests = resolve(max_running_requests, "inference.max_running_requests", 32)
final_chunked_prefill_size = resolve(chunked_prefill_size, "inference.chunked_prefill_size", 4096)
final_mem_fraction_static = resolve(mem_fraction_static, "inference.mem_fraction_static", 0.98)
final_watchdog_timeout = resolve(watchdog_timeout, "inference.watchdog_timeout", 3000)
final_served_model_name = resolve(served_model_name, "inference.served_model_name", "")
# Performance flags
final_disable_shared_experts_fusion = resolve(
disable_shared_experts_fusion, "inference.disable_shared_experts_fusion", True
)
# Pass extra CLI parameters
extra_params = {}
# Parser parameters (from interactive mode or None in non-interactive mode)
final_tool_call_parser = None
final_reasoning_parser = None
if "tool_call_parser" in locals() and tool_call_parser:
final_tool_call_parser = tool_call_parser
if "reasoning_parser" in locals() and reasoning_parser:
final_reasoning_parser = reasoning_parser
cmd = _build_sglang_command(
model_path=resolved_model_path,
weights_path=resolved_weights_path,
model_info=model_info,
host=final_host,
port=final_port,
gpu_experts=final_gpu_experts,
@@ -490,6 +472,8 @@ def _run_impl(
watchdog_timeout=final_watchdog_timeout,
served_model_name=final_served_model_name,
disable_shared_experts_fusion=final_disable_shared_experts_fusion,
tool_call_parser=final_tool_call_parser,
reasoning_parser=final_reasoning_parser,
settings=settings,
extra_model_params=extra_params,
extra_cli_args=extra_cli_args,
@@ -508,11 +492,9 @@ def _run_impl(
console.print()
print_step("Configuration")
# Display model name
model_display_name = user_model_obj.name if user_model_obj else resolved_model_path.name
console.print(f" Model: [bold]{model_display_name}[/bold]")
console.print(f" Path: [dim]{resolved_model_path}[/dim]")
@@ -572,88 +554,13 @@ def _run_impl(
raise typer.Exit(1)
# Dead code removed: _find_model_path() and _find_weights_path()
# These functions were part of the old builtin model system
def _build_sglang_command(
model_path: Path,
weights_path: Optional[Path],
model_info: Optional[ModelInfo],
host: str,
port: int,
gpu_experts: int,
@@ -670,6 +577,8 @@ def _build_sglang_command(
watchdog_timeout: int,
served_model_name: str,
disable_shared_experts_fusion: bool,
tool_call_parser: Optional[str],
reasoning_parser: Optional[str],
settings,
extra_model_params: Optional[dict] = None, # New parameter for additional params
extra_cli_args: Optional[list[str]] = None, # Extra args from CLI to pass to sglang
@@ -700,9 +609,6 @@ def _build_sglang_command(
elif cpu_threads > 0 or gpu_experts > 1:
# CPU offloading configured - use kt-kernel
use_kt_kernel = True
if use_kt_kernel:
# Add kt-weight-path: use quantized weights if available, otherwise use model path
@@ -723,6 +629,7 @@ def _build_sglang_command(
kt_method,
"--kt-gpu-prefill-token-threshold",
str(kt_gpu_prefill_threshold),
"--kt-enable-dynamic-expert-update", # Enable dynamic expert updates
]
)
@@ -757,6 +664,16 @@ def _build_sglang_command(
if disable_shared_experts_fusion:
cmd.append("--disable-shared-experts-fusion")
# Add FP8 backend if using FP8 method
if "FP8" in kt_method.upper():
cmd.extend(["--fp8-gemm-backend", "triton"])
# Add parsers if specified
if tool_call_parser:
cmd.extend(["--tool-call-parser", tool_call_parser])
if reasoning_parser:
cmd.extend(["--reasoning-parser", reasoning_parser])
# Add any extra parameters from model defaults that weren't explicitly handled
if extra_model_params:
# List of parameters already handled above
@@ -801,30 +718,31 @@ def _build_sglang_command(
return cmd
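The conditional flag handling above (FP8 GEMM backend, optional parsers) appends a flag/value pair only when the setting is present. Sketched in isolation; `optional_flags` and the parser names in the example are illustrative, while the flag strings match those used in this diff:

```python
from typing import Optional


def optional_flags(
    kt_method: str,
    tool_call_parser: Optional[str],
    reasoning_parser: Optional[str],
) -> list[str]:
    cmd: list[str] = []
    # FP8 methods need an explicit GEMM backend
    if "FP8" in kt_method.upper():
        cmd.extend(["--fp8-gemm-backend", "triton"])
    # Parsers are only passed through when configured
    if tool_call_parser:
        cmd.extend(["--tool-call-parser", tool_call_parser])
    if reasoning_parser:
        cmd.extend(["--reasoning-parser", reasoning_parser])
    return cmd


# "qwen25" is a hypothetical parser name for illustration
print(optional_flags("KT_FP8", "qwen25", None))
```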
def _interactive_model_selection(user_registry, settings) -> Optional[str]:
"""Show interactive model selection interface.
Returns:
Selected model name or None if cancelled.
"""
from rich.panel import Panel
from rich.table import Table
from rich.prompt import Prompt
# Get all user models
all_models = user_registry.list_models()
if not all_models:
console.print()
print_warning("No models registered.")
console.print()
console.print(f" Add models with: [cyan]kt model scan[/cyan]")
console.print(f" Or manually: [cyan]kt model add /path/to/model[/cyan]")
console.print()
return None
console.print()
console.print(
Panel.fit(
"Select a model to run",
border_style="cyan",
)
)
@@ -834,54 +752,30 @@ def _interactive_model_selection(registry, settings) -> Optional[str]:
choices = []
choice_map = {}  # index -> model name
# Show all user models
console.print("[bold green]Available Models:[/bold green]")
console.print()
for i, model in enumerate(all_models, 1):
    # Flag entries whose registered path no longer exists
    status = "" if model.path_exists() else " [red]✗ Missing[/red]"
    console.print(f" [cyan][{i}][/cyan] [bold]{model.name}[/bold]{status}")
    console.print(f"     [dim]{model.format} - {model.path}[/dim]")
    choices.append(str(i))
    choice_map[str(i)] = model.name
console.print()
# Add cancel option
cancel_idx = str(len(choices) + 1)
console.print(f" [cyan][{cancel_idx}][/cyan] [dim]Cancel[/dim]")
choices.append(cancel_idx)
console.print()
# Prompt for selection
try:
selection = Prompt.ask(
"Select model",
choices=choices,
default="1" if choices else cancel_idx,
)


@@ -9,7 +9,7 @@ _kt_completion() {
prev="${COMP_WORDS[COMP_CWORD-1]}"
# Main commands
local commands="version run chat quant edit bench microbench doctor model config sft"
# Global options
local global_opts="--help --version"
@@ -36,6 +36,10 @@ _kt_completion() {
local quant_opts="--method --output --help"
COMPREPLY=( $(compgen -W "${quant_opts}" -- ${cur}) )
;;
edit)
local edit_opts="--help"
COMPREPLY=( $(compgen -W "${edit_opts}" -- ${cur}) )
;;
bench|microbench)
local bench_opts="--model --config --help"
COMPREPLY=( $(compgen -W "${bench_opts}" -- ${cur}) )


@@ -190,6 +190,70 @@ MESSAGES: dict[str, dict[str, str]] = {
"quant_progress": "Quantizing...",
"quant_complete": "Quantization complete!",
"quant_input_not_found": "Input model not found at {path}",
"quant_cpu_threads": "CPU threads: {threads}",
"quant_numa_nodes": "NUMA nodes: {nodes}",
"quant_time_warning": "Quantization may take 30-60 minutes depending on model size.",
"quant_disk_analysis": "Disk Space Analysis:",
"quant_source_size": "Source model size:",
"quant_estimated_size": "Estimated output size:",
"quant_available_space": "Available space:",
"quant_insufficient_space": "WARNING: Insufficient disk space!",
"quant_required_space": "Required space (with 20% buffer):",
"quant_shortage": "Shortage:",
"quant_may_fail": "Quantization may fail or produce incomplete files.",
"quant_continue_anyway": "Continue anyway?",
"quant_settings": "Quantization Settings:",
"quant_registered": "Quantized model registered: {name}",
"quant_view_with": "View with:",
"quant_use_with": "Use with:",
"quant_register_failed": "Failed to auto-register model: {error}",
"quant_output_exists": "Output path already exists: {path}",
"quant_using_unique": "Using unique name: {path}",
# Interactive quant
"quant_interactive_title": "Interactive Quantization Configuration",
"quant_new_model_notice": "⚠ Note: Some newer models cannot be quantized yet (the conversion script has not been adapted). We recommend running them at original precision instead (no weight conversion needed).",
"quant_no_moe_models": "No MoE models found for quantization.",
"quant_only_moe": "Only MoE models (e.g., DeepSeek-V3) can be quantized to AMX format.",
"quant_add_models": "Add models with: {command}",
"quant_moe_available": "MoE Models Available for Quantization:",
"quant_select_model": "Select model to quantize",
"quant_invalid_choice": "Invalid choice",
"quant_step2_method": "Step 2: Quantization Method",
"quant_method_label": "Quantization Method:",
"quant_int4_desc": "INT4",
"quant_int8_desc": "INT8",
"quant_select_method": "Select quantization method",
"quant_input_type_label": "Input Weight Type:",
"quant_fp8_desc": "FP8 (for 8-bit float weights)",
"quant_fp16_desc": "FP16 (for 16-bit float weights)",
"quant_bf16_desc": "BF16 (for Brain Float 16 weights)",
"quant_select_input_type": "Select input type",
"quant_step3_cpu": "Step 3: CPU Configuration",
"quant_cpu_threads_prompt": "CPU Threads (1 to {max})",
"quant_numa_nodes_prompt": "NUMA Nodes (1 to {max})",
"quant_use_gpu_label": "Use GPU for conversion?",
"quant_gpu_speedup": "GPU can significantly speed up the quantization process",
"quant_enable_gpu": "Enable GPU acceleration?",
"quant_step4_output": "Step 4: Output Path",
"quant_default_path": "Default:",
"quant_use_default": "Use default output path?",
"quant_custom_path": "Enter custom output path",
"quant_output_exists_warn": "⚠ Output path already exists: {path}",
"quant_using_unique_name": "→ Using unique name: {path}",
"quant_config_summary": "Configuration Summary",
"quant_summary_model": "Model:",
"quant_summary_method": "Method:",
"quant_summary_input_type": "Input Type:",
"quant_summary_cpu_threads": "CPU Threads:",
"quant_summary_numa": "NUMA Nodes:",
"quant_summary_gpu": "Use GPU:",
"quant_summary_output": "Output Path:",
"quant_start_question": "Start quantization?",
"quant_cancelled": "Cancelled",
"quant_config_complete": "Configuration complete",
"quant_time_elapsed": "Time elapsed:",
"yes": "Yes",
"no": "No",
# SFT command
"sft_mode_train": "Training mode",
"sft_mode_chat": "Chat mode",
@@ -247,6 +311,113 @@ MESSAGES: dict[str, dict[str, str]] = {
"chat_proxy_detected": "Proxy detected in environment",
"chat_proxy_confirm": "Use proxy for connection?",
"chat_proxy_disabled": "Proxy disabled for this session",
"chat_openai_required": "OpenAI Python SDK is required for chat functionality.",
"chat_install_hint": "Install it with:",
"chat_title": "KTransformers Chat",
"chat_server": "Server",
"chat_temperature": "Temperature",
"chat_max_tokens": "Max tokens",
"chat_help_hint": "Type '/help' for commands, '/quit' to exit",
"chat_connecting": "Connecting to server...",
"chat_no_models": "No models available on server",
"chat_model_not_found": "Model '{model}' not found. Available models: {available}",
"chat_connected": "Connected to model: {model}",
"chat_connect_failed": "Failed to connect to server: {error}",
"chat_server_not_running": "Make sure the model server is running:",
"chat_user_prompt": "You",
"chat_assistant_prompt": "Assistant",
"chat_generation_error": "Error generating response: {error}",
"chat_interrupted": "Chat interrupted. Goodbye!",
"chat_history_saved": "History saved to: {path}",
"chat_goodbye": "Goodbye!",
"chat_help_title": "Available Commands:",
"chat_help_content": "/help, /h - Show this help message\n/quit, /exit, /q - Exit chat\n/clear, /c - Clear conversation history\n/history, /hist - Show conversation history\n/info, /i - Show current settings\n/retry, /r - Regenerate last response",
"chat_history_cleared": "Conversation history cleared",
"chat_no_history": "No conversation history",
"chat_history_title": "History ({count} messages)",
"chat_info_title": "Current Settings:",
"chat_info_content": "Temperature: {temperature}\nMax tokens: {max_tokens}\nMessages: {messages}",
"chat_retrying": "Retrying last response...",
"chat_no_retry": "No previous response to retry",
"chat_unknown_command": "Unknown command: {command}",
"chat_unknown_hint": "Type /help for available commands",
# Run Interactive
"run_int_no_moe_models": "No MoE GPU models found.",
"run_int_add_models": "Add models with: kt model scan",
"run_int_list_all": "List all models: kt model list --all",
"run_int_step1_title": "Step 1: Select Model (GPU MoE Models)",
"run_int_select_model": "Select model",
"run_int_step2_title": "Step 2: Select Inference Method",
"run_int_method_raw": "RAW Precision (FP8/FP8_PERCHANNEL/BF16/RAWINT4)",
"run_int_method_amx": "AMX Quantization (INT4/INT8)",
"run_int_method_gguf": "GGUF (Llamafile)",
"run_int_method_saved": "Use Saved Configuration",
"run_int_select_method": "Select inference method",
"run_int_raw_precision": "RAW Precision:",
"run_int_select_precision": "Select precision",
"run_int_amx_method": "AMX Method:",
"run_int_select_amx": "Select AMX method",
"run_int_step3_title": "Step 3: NUMA and CPU Configuration",
"run_int_numa_nodes": "NUMA Nodes (1-{max})",
"run_int_cpu_threads": "CPU Threads per NUMA (1-{max})",
"run_int_amx_warning": "⚠ Warning: AMX INT4/INT8 requires compatible CPU. Check with: kt doctor",
"run_int_step4_title": "Step 4: GPU Experts Configuration",
"run_int_gpu_experts": "GPU Experts per Layer (0-{max})",
"run_int_gpu_experts_info": "Total experts: {total}, Activated per token: {active}",
"run_int_step5_title": "Step 5: KV Cache Configuration",
"run_int_kv_cache_size": "KV Cache Size (tokens)",
"run_int_chunk_prefill": "Enable Chunk Prefill?",
"run_int_chunk_size": "Chunk Prefill Size (tokens)",
"run_int_gpu_prefill_threshold": "GPU Prefill Threshold (tokens)",
"run_int_step6_title": "Step 6: GPU Selection and Tensor Parallelism",
"run_int_available_gpus": "Available GPUs:",
"run_int_gpu_id": "GPU {id}",
"run_int_vram_info": "{name} ({total:.1f}GB total, {free:.1f}GB free)",
"run_int_select_gpus": "Select GPU IDs (comma-separated)",
"run_int_invalid_gpu_range": "All GPU IDs must be between 0 and {max}",
"run_int_tp_size": "TP Size (must be power of 2: 1,2,4,8...)",
"run_int_tp_mismatch": "TP size must match number of selected GPUs ({count})",
"run_int_tp_not_power_of_2": "TP size must be a power of 2",
"run_int_mem_fraction": "Static Memory Fraction (0.0-1.0)",
"run_int_using_saved_mem": "Using saved memory fraction: {fraction}",
"run_int_step7_title": "Step 7: Parser Configuration (Optional)",
"run_int_tool_call_parser": "Tool Call Parser (press Enter to skip)",
"run_int_reasoning_parser": "Reasoning Parser (press Enter to skip)",
"run_int_step8_title": "Step 8: Host and Port Configuration",
"run_int_host": "Host",
"run_int_port": "Port",
"run_int_port_occupied": "⚠ Port {port} is already in use",
"run_int_port_suggestion": "Suggested available port: {port}",
"run_int_use_suggested": "Use suggested port?",
"run_int_saved_configs": "Saved Configurations:",
"run_int_config_name": "Configuration {num}",
"run_int_kt_method": "KT Method:",
"run_int_numa_nodes_label": "NUMA Nodes:",
"run_int_cpu_threads_label": "CPU Threads:",
"run_int_gpu_experts_label": "GPU Experts:",
"run_int_tp_size_label": "TP Size:",
"run_int_mem_fraction_label": "Memory Fraction:",
"run_int_server_label": "Server:",
"run_int_kv_cache_label": "KV Cache:",
"run_int_chunk_prefill_label": "Chunk Prefill:",
"run_int_gpu_prefill_label": "GPU Prefill Thr:",
"run_int_tool_parser_label": "Tool Call Parser:",
"run_int_reasoning_parser_label": "Reasoning Parser:",
"run_int_command_label": "Command:",
"run_int_select_config": "Select configuration",
"run_int_gpu_select_required": "Please select {tp} GPUs (TP size from saved config)",
"run_int_port_check_title": "Port Configuration",
"run_int_port_checking": "Checking port {port} availability...",
"run_int_port_available": "Port {port} is available",
"run_int_saved_config_title": "Saved Configuration",
"run_int_save_config_title": "Save Configuration",
"run_int_save_config_prompt": "Save this configuration for future use?",
"run_int_config_name_prompt": "Configuration name",
"run_int_config_name_default": "Config {timestamp}",
"run_int_config_saved": "Configuration saved: {name}",
"run_int_config_summary": "Configuration Complete",
"run_int_model_label": "Model:",
"run_int_selected_gpus_label": "Selected GPUs:",
# Model command
"model_supported_title": "KTransformers Supported Models",
"model_column_model": "Model",
@@ -282,6 +453,180 @@ MESSAGES: dict[str, dict[str, str]] = {
"model_column_name": "Name",
"model_column_hf_repo": "HuggingFace Repo",
"model_column_aliases": "Aliases",
# Model management - new user registry system
"model_no_registered_models": "No models registered yet.",
"model_scan_hint": "Scan for models: kt model scan",
"model_add_hint": "Add a model: kt model add /path/to/model",
"model_registered_models_title": "Registered Models",
"model_column_format": "Format",
"model_column_repo": "Repository",
"model_column_sha256": "SHA256",
"model_non_moe_hidden_hint": "Detected {count} non-MoE models, use kt model list --all to show all",
"model_usage_title": "Common Operations:",
"model_usage_info": "View details:",
"model_usage_edit": "Edit model:",
"model_usage_verify": "Verify integrity:",
"model_usage_quant": "Quantize model:",
"model_usage_run": "Run model:",
"model_usage_scan": "Scan for models:",
"model_usage_add": "Add model:",
"model_usage_verbose": "View with file details:",
"model_no_storage_paths": "No storage paths configured.",
"model_add_path_hint": "Add a storage path with: kt config set model.storage_paths /path/to/models",
"model_scanning_paths": "Scanning configured storage paths...",
"model_scanning_progress": "Scanning: {path}",
"model_scan_warnings_title": "Warnings",
"model_scan_no_models_found": "No models found in configured paths.",
"model_scan_check_paths_hint": "Check your storage paths: kt config get model.storage_paths",
"model_scan_min_size_hint": "Folders must be ≥{size}GB to be detected as models.",
"model_scan_found_title": "Found {count} new model(s)",
"model_column_path": "Path",
"model_column_size": "Size",
"model_scan_auto_adding": "Auto-adding models...",
"model_added": "Added: {name}",
"model_add_failed": "Failed to add {name}: {error}",
"model_scan_complete": "Scan complete! Added {count} model(s).",
"model_scan_interactive_prompt": "Commands: edit <id> | del <id> | done",
"model_scan_cmd_edit": "Set custom name for model",
"model_scan_cmd_delete": "Skip this model",
"model_scan_cmd_done": "Finish and add models",
"model_scan_marked_skip": "Skipped model #{id}",
"model_scan_invalid_id": "Invalid model ID: {id}",
"model_scan_invalid_command": "Invalid command. Use: edit <id> | del <id> | done",
"model_scan_edit_model": "Edit model {id}",
"model_scan_edit_note": "You can change the model name before adding it to registry",
"model_scan_adding_models": "Adding {count} model(s)...",
"model_scan_next_steps": "Next Steps",
"model_scan_view_hint": "View registered models: kt model list",
"model_scan_edit_hint": "Edit model details: kt model edit <name>",
"model_scan_no_models_added": "No models were added.",
"model_add_path_not_exist": "Error: Path does not exist: {path}",
"model_add_not_directory": "Error: Path is not a directory: {path}",
"model_add_already_registered": "This path is already registered as: {name}",
"model_add_view_hint": "View with: kt model info {name}",
"model_add_scanning": "Scanning model files...",
"model_add_scan_failed": "Failed to scan model: {error}",
"model_add_no_model_files": "No model files found in {path}",
"model_add_supported_formats": "Supported: *.safetensors, *.gguf (folder ≥10GB)",
"model_add_detected": "Detected: {format} format, {size}, {count} file(s)",
"model_add_name_conflict": "Name '{name}' already exists.",
"model_add_prompt_name": "Enter a name for this model",
"model_add_name_exists": "Name already exists. Please choose another name:",
"model_add_configure_repo": "Configure repository information for SHA256 verification?",
"model_add_repo_type_prompt": "Select repository type:",
"model_add_choice": "Choice",
"model_add_repo_id_prompt": "Enter repository ID (e.g., deepseek-ai/DeepSeek-V3)",
"model_add_success": "Successfully added model: {name}",
"model_add_verify_hint": "Verify integrity: kt model verify {name}",
"model_add_edit_later_hint": "Edit details later: kt model edit {name}",
"model_add_failed_generic": "Failed to add model: {error}",
"model_edit_not_found": "Model '{name}' not found.",
"model_edit_list_hint": "List models: kt model list",
"model_edit_current_config": "Current Configuration",
"model_edit_what_to_edit": "What would you like to edit?",
"model_edit_option_name": "Edit name",
"model_edit_option_repo": "Configure repository info",
"model_edit_option_delete": "Delete this model",
"model_edit_option_cancel": "Cancel / Exit",
"model_edit_choice_prompt": "Select option",
"model_edit_new_name": "Enter new name",
"model_edit_name_conflict": "Name '{name}' already exists. Please choose another:",
"model_edit_name_updated": "Name updated: {old}{new}",
"model_edit_repo_type_prompt": "Repository type (or enter to remove repo info):",
"model_edit_repo_remove": "Remove repository info",
"model_edit_repo_id_prompt": "Enter repository ID",
"model_edit_repo_removed": "Repository info removed",
"model_edit_repo_updated": "Repository configured: {repo_type}{repo_id}",
"model_edit_delete_warning": "Delete model '{name}' from registry?",
"model_edit_delete_note": "Note: This only removes the registry entry. Model files in {path} will NOT be deleted.",
"model_edit_delete_confirm": "Confirm deletion?",
"model_edit_deleted": "Model '{name}' deleted from registry",
"model_edit_delete_cancelled": "Deletion cancelled",
"model_edit_cancelled": "Edit cancelled",
# Model edit - Interactive selection
"model_edit_select_title": "Select Model to Edit",
"model_edit_select_model": "Select model",
"model_edit_invalid_choice": "Invalid choice",
"model_edit_no_models": "No models found in registry.",
"model_edit_add_hint_scan": "Add models with:",
"model_edit_add_hint_add": "Or:",
# Model edit - Display
"model_edit_gpu_links": "GPU Links:",
# Model edit - Menu options
"model_edit_manage_gpu_links": "Manage GPU Links",
"model_edit_save_changes": "Save changes",
"model_edit_has_changes": "(has changes)",
"model_edit_no_changes": "(no changes)",
# Model edit - Pending changes messages
"model_edit_name_pending": "Name will be updated when you save changes.",
"model_edit_repo_remove_pending": "Repository info will be removed when you save changes.",
"model_edit_repo_update_pending": "Repository info will be updated when you save changes.",
# Model edit - GPU link management
"model_edit_gpu_links_title": "Manage GPU Links for {name}",
"model_edit_current_gpu_links": "Current GPU links:",
"model_edit_no_gpu_links": "No GPU links configured.",
"model_edit_gpu_options": "Options:",
"model_edit_gpu_add": "Add GPU link",
"model_edit_gpu_remove": "Remove GPU link",
"model_edit_gpu_clear": "Clear all GPU links",
"model_edit_gpu_back": "Back to main menu",
"model_edit_gpu_choose_option": "Choose option",
"model_edit_gpu_none_available": "No GPU models available to link.",
"model_edit_gpu_available_models": "Available GPU models:",
"model_edit_gpu_already_linked": "(already linked)",
"model_edit_gpu_enter_number": "Enter GPU model number to add",
"model_edit_gpu_link_pending": "GPU link will be added when you save changes: {name}",
"model_edit_gpu_already_exists": "This GPU model is already linked.",
"model_edit_gpu_invalid_choice": "Invalid choice.",
"model_edit_gpu_invalid_input": "Invalid input.",
"model_edit_gpu_none_to_remove": "No GPU links to remove.",
"model_edit_gpu_choose_to_remove": "Choose GPU link to remove:",
"model_edit_gpu_enter_to_remove": "Enter number to remove",
"model_edit_gpu_remove_pending": "GPU link will be removed when you save changes: {name}",
"model_edit_gpu_none_to_clear": "No GPU links to clear.",
"model_edit_gpu_clear_confirm": "Remove all GPU links?",
"model_edit_gpu_clear_pending": "All GPU links will be removed when you save changes.",
"model_edit_cancelled_short": "Cancelled.",
# Model edit - Save operation
"model_edit_no_changes_to_save": "No changes to save.",
"model_edit_saving": "Saving changes...",
"model_edit_saved": "Changes saved successfully!",
"model_edit_updated_config": "Updated Configuration:",
"model_edit_repo_changed_warning": "⚠ Repository information has changed.",
"model_edit_verify_hint": "Run [cyan]kt model verify[/cyan] to verify model integrity with SHA256 checksums.",
"model_edit_discard_changes": "Discard unsaved changes?",
"model_info_not_found": "Model '{name}' not found.",
"model_info_list_hint": "List all models: kt model list",
"model_remove_not_found": "Model '{name}' not found.",
"model_remove_list_hint": "List models: kt model list",
"model_remove_warning": "Remove model '{name}' from registry?",
"model_remove_note": "Note: This only removes the registry entry. Model files will NOT be deleted from {path}.",
"model_remove_confirm": "Confirm removal?",
"model_remove_cancelled": "Removal cancelled",
"model_removed": "Model '{name}' removed from registry",
"model_remove_failed": "Failed to remove model: {error}",
"model_refresh_checking": "Checking model paths...",
"model_refresh_all_valid": "All models are valid! ({count} model(s) checked)",
"model_refresh_total": "Total models: {total}",
"model_refresh_missing_found": "Found {count} missing model(s)",
"model_refresh_suggestions": "Suggested Actions",
"model_refresh_remove_hint": "Remove from registry: kt model remove <name>",
"model_refresh_rescan_hint": "Re-scan for models: kt model scan",
"model_verify_not_found": "Model '{name}' not found.",
"model_verify_list_hint": "List models: kt model list",
"model_verify_no_repo": "Model '{name}' has no repository information configured.",
"model_verify_config_hint": "Configure repository: kt model edit {name}",
"model_verify_path_missing": "Model path does not exist: {path}",
"model_verify_starting": "Verifying model integrity...",
"model_verify_progress": "Repository: {repo_type}{repo_id}",
"model_verify_not_implemented": "SHA256 verification not implemented yet",
"model_verify_future_note": "This feature will fetch official SHA256 hashes from {repo_type} and compare with local files.",
"model_verify_passed": "Verification passed! All files match official hashes.",
"model_verify_failed": "Verification failed! {count} file(s) have hash mismatches.",
"model_verify_all_no_repos": "No models have repository information configured.",
"model_verify_all_config_hint": "Configure repos using: kt model edit <name>",
"model_verify_all_found": "Found {count} model(s) with repository info",
"model_verify_all_manual_hint": "Verify specific model: kt model verify <name>",
# Coming soon
"feature_coming_soon": "This feature is coming soon...",
},
@@ -465,6 +810,70 @@ MESSAGES: dict[str, dict[str, str]] = {
"quant_progress": "正在量化...",
"quant_complete": "量化完成!",
"quant_input_not_found": "未找到输入模型: {path}",
"quant_cpu_threads": "CPU 线程数: {threads}",
"quant_numa_nodes": "NUMA 节点数: {nodes}",
"quant_time_warning": "量化可能需要 30-60 分钟,具体取决于模型大小。",
"quant_disk_analysis": "磁盘空间分析:",
"quant_source_size": "源模型大小:",
"quant_estimated_size": "预估输出大小:",
"quant_available_space": "可用空间:",
"quant_insufficient_space": "警告:磁盘空间不足!",
"quant_required_space": "所需空间含20%缓冲):",
"quant_shortage": "不足:",
"quant_may_fail": "量化可能失败或生成不完整的文件。",
"quant_continue_anyway": "仍然继续?",
"quant_settings": "量化设置:",
"quant_registered": "量化模型已注册:{name}",
"quant_view_with": "查看:",
"quant_use_with": "使用:",
"quant_register_failed": "自动注册模型失败:{error}",
"quant_output_exists": "输出路径已存在:{path}",
"quant_using_unique": "使用唯一名称:{path}",
# Interactive quant
"quant_interactive_title": "交互式量化配置",
"quant_new_model_notice": "⚠ 注意:部分新模型暂时无法量化(转换脚本未适配),推荐使用原精度进行推理(无需转换权重)。",
"quant_no_moe_models": "未找到可量化的 MoE 模型。",
"quant_only_moe": "只有 MoE 模型(如 DeepSeek-V3可以被量化为 AMX 格式。",
"quant_add_models": "添加模型:{command}",
"quant_moe_available": "可量化的 MoE 模型:",
"quant_select_model": "选择要量化的模型",
"quant_invalid_choice": "无效选择",
"quant_step2_method": "第 2 步:量化方法",
"quant_method_label": "量化方法:",
"quant_int4_desc": "INT4",
"quant_int8_desc": "INT8",
"quant_select_method": "选择量化方法",
"quant_input_type_label": "输入权重类型:",
"quant_fp8_desc": "FP8适用于 8 位浮点权重)",
"quant_fp16_desc": "FP16适用于 16 位浮点权重)",
"quant_bf16_desc": "BF16适用于 Brain Float 16 权重)",
"quant_select_input_type": "选择输入类型",
"quant_step3_cpu": "第 3 步CPU 配置",
"quant_cpu_threads_prompt": "CPU 线程数1 到 {max}",
"quant_numa_nodes_prompt": "NUMA 节点数1 到 {max}",
"quant_use_gpu_label": "是否使用 GPU 进行转换?",
"quant_gpu_speedup": "GPU 可以显著加快量化速度",
"quant_enable_gpu": "启用 GPU 加速?",
"quant_step4_output": "第 4 步:输出路径",
"quant_default_path": "默认:",
"quant_use_default": "使用默认输出路径?",
"quant_custom_path": "输入自定义输出路径",
"quant_output_exists_warn": "⚠ 输出路径已存在:{path}",
"quant_using_unique_name": "→ 使用唯一名称:{path}",
"quant_config_summary": "配置摘要",
"quant_summary_model": "模型:",
"quant_summary_method": "方法:",
"quant_summary_input_type": "输入类型:",
"quant_summary_cpu_threads": "CPU 线程数:",
"quant_summary_numa": "NUMA 节点数:",
"quant_summary_gpu": "使用 GPU",
"quant_summary_output": "输出路径:",
"quant_start_question": "开始量化?",
"quant_cancelled": "已取消",
"quant_config_complete": "配置完成",
"quant_time_elapsed": "耗时:",
"yes": "",
"no": "",
# SFT command
"sft_mode_train": "训练模式",
"sft_mode_chat": "聊天模式",
@@ -522,6 +931,113 @@ MESSAGES: dict[str, dict[str, str]] = {
"chat_proxy_detected": "检测到环境中存在代理设置",
"chat_proxy_confirm": "是否使用代理连接?",
"chat_proxy_disabled": "已在本次会话中禁用代理",
"chat_openai_required": "聊天功能需要 OpenAI Python SDK。",
"chat_install_hint": "安装命令:",
"chat_title": "KTransformers 对话",
"chat_server": "服务器",
"chat_temperature": "温度",
"chat_max_tokens": "最大 tokens",
"chat_help_hint": "输入 '/help' 查看命令,'/quit' 退出",
"chat_connecting": "正在连接服务器...",
"chat_no_models": "服务器上没有可用模型",
"chat_model_not_found": "未找到模型 '{model}'。可用模型:{available}",
"chat_connected": "已连接到模型:{model}",
"chat_connect_failed": "连接服务器失败:{error}",
"chat_server_not_running": "请确保模型服务器正在运行:",
"chat_user_prompt": "用户",
"chat_assistant_prompt": "助手",
"chat_generation_error": "生成回复时出错:{error}",
"chat_interrupted": "对话已中断。再见!",
"chat_history_saved": "历史记录已保存到:{path}",
"chat_goodbye": "再见!",
"chat_help_title": "可用命令:",
"chat_help_content": "/help, /h - 显示此帮助信息\n/quit, /exit, /q - 退出聊天\n/clear, /c - 清除对话历史\n/history, /hist - 显示对话历史\n/info, /i - 显示当前设置\n/retry, /r - 重新生成上一个回复",
"chat_history_cleared": "对话历史已清除",
"chat_no_history": "暂无对话历史",
"chat_history_title": "历史记录({count} 条消息)",
"chat_info_title": "当前设置:",
"chat_info_content": "温度:{temperature}\n最大 tokens{max_tokens}\n消息数:{messages}",
"chat_retrying": "正在重试上一个回复...",
"chat_no_retry": "没有可重试的回复",
"chat_unknown_command": "未知命令:{command}",
"chat_unknown_hint": "输入 /help 查看可用命令",
# Run Interactive
"run_int_no_moe_models": "未找到 MoE GPU 模型。",
"run_int_add_models": "添加模型kt model scan",
"run_int_list_all": "列出所有模型kt model list --all",
"run_int_step1_title": "第 1 步选择模型GPU MoE 模型)",
"run_int_select_model": "选择模型",
"run_int_step2_title": "第 2 步:选择推理方法",
"run_int_method_raw": "RAW 精度FP8/FP8_PERCHANNEL/BF16/RAWINT4",
"run_int_method_amx": "AMX 量化INT4/INT8",
"run_int_method_gguf": "GGUFLlamafile",
"run_int_method_saved": "使用已保存的配置",
"run_int_select_method": "选择推理方法",
"run_int_raw_precision": "RAW 精度:",
"run_int_select_precision": "选择精度",
"run_int_amx_method": "AMX 方法:",
"run_int_select_amx": "选择 AMX 方法",
"run_int_step3_title": "第 3 步NUMA 和 CPU 配置",
"run_int_numa_nodes": "NUMA 节点数1-{max}",
"run_int_cpu_threads": "每个 NUMA 的 CPU 线程数1-{max}",
"run_int_amx_warning": "⚠ 警告AMX INT4/INT8 需要兼容的 CPU。检查命令kt doctor",
"run_int_step4_title": "第 4 步GPU 专家配置",
"run_int_gpu_experts": "每层 GPU 专家数0-{max}",
"run_int_gpu_experts_info": "总专家数:{total},每 token 激活:{active}",
"run_int_step5_title": "第 5 步KV Cache 配置",
"run_int_kv_cache_size": "KV Cache 大小tokens",
"run_int_chunk_prefill": "启用分块预填充?",
"run_int_chunk_size": "分块预填充大小tokens",
"run_int_gpu_prefill_threshold": "GPU 预填充阈值tokens",
"run_int_step6_title": "第 6 步GPU 选择和张量并行",
"run_int_available_gpus": "可用 GPU",
"run_int_gpu_id": "GPU {id}",
"run_int_vram_info": "{name}(总计 {total:.1f}GB空闲 {free:.1f}GB",
"run_int_select_gpus": "选择 GPU ID逗号分隔",
"run_int_invalid_gpu_range": "所有 GPU ID 必须在 0 到 {max} 之间",
"run_int_tp_size": "TP 大小(必须是 2 的幂1,2,4,8...",
"run_int_tp_mismatch": "TP 大小必须与选择的 GPU 数量匹配({count}",
"run_int_tp_not_power_of_2": "TP 大小必须是 2 的幂",
"run_int_mem_fraction": "静态内存占用比例0.0-1.0",
"run_int_using_saved_mem": "使用已保存的内存占用比例:{fraction}",
"run_int_step7_title": "第 7 步:解析器配置(可选)",
"run_int_tool_call_parser": "工具调用解析器(按回车跳过)",
"run_int_reasoning_parser": "推理解析器(按回车跳过)",
"run_int_step8_title": "第 8 步:主机和端口配置",
"run_int_host": "主机",
"run_int_port": "端口",
"run_int_port_occupied": "⚠ 端口 {port} 已被占用",
"run_int_port_suggestion": "建议使用可用端口:{port}",
"run_int_use_suggested": "使用建议的端口?",
"run_int_saved_configs": "已保存的配置:",
"run_int_config_name": "配置 {num}",
"run_int_kt_method": "KT 方法:",
"run_int_numa_nodes_label": "NUMA 节点:",
"run_int_cpu_threads_label": "CPU 线程:",
"run_int_gpu_experts_label": "GPU 专家:",
"run_int_tp_size_label": "TP 大小:",
"run_int_mem_fraction_label": "内存占用比例:",
"run_int_server_label": "服务器:",
"run_int_kv_cache_label": "KV Cache",
"run_int_chunk_prefill_label": "分块预填充:",
"run_int_gpu_prefill_label": "GPU 预填充阈值:",
"run_int_tool_parser_label": "工具调用解析器:",
"run_int_reasoning_parser_label": "推理解析器:",
"run_int_command_label": "命令:",
"run_int_select_config": "选择配置",
"run_int_gpu_select_required": "请选择 {tp} 个 GPU来自已保存配置的 TP 大小)",
"run_int_port_check_title": "端口配置",
"run_int_port_checking": "正在检查端口 {port} 可用性...",
"run_int_port_available": "端口 {port} 可用",
"run_int_saved_config_title": "已保存的配置",
"run_int_save_config_title": "保存配置",
"run_int_save_config_prompt": "保存此配置以供将来使用?",
"run_int_config_name_prompt": "配置名称",
"run_int_config_name_default": "配置 {timestamp}",
"run_int_config_saved": "配置已保存:{name}",
"run_int_config_summary": "配置完成",
"run_int_model_label": "模型:",
"run_int_selected_gpus_label": "已选择的 GPU",
# Model command
"model_supported_title": "KTransformers 支持的模型",
"model_column_model": "模型",
@@ -557,6 +1073,180 @@ MESSAGES: dict[str, dict[str, str]] = {
"model_column_name": "名称",
"model_column_hf_repo": "HuggingFace 仓库",
"model_column_aliases": "别名",
# Model management - new user registry system
"model_no_registered_models": "尚未注册任何模型。",
"model_scan_hint": "扫描模型: kt model scan",
"model_add_hint": "添加模型: kt model add /path/to/model",
"model_registered_models_title": "已注册的模型",
"model_column_format": "格式",
"model_column_repo": "仓库",
"model_column_sha256": "SHA256",
"model_non_moe_hidden_hint": "检测到 {count} 个非MoE模型使用 kt model list --all 展示全部",
"model_usage_title": "常用操作:",
"model_usage_info": "查看详情:",
"model_usage_edit": "编辑模型:",
"model_usage_verify": "校验权重:",
"model_usage_quant": "量化模型:",
"model_usage_run": "运行模型:",
"model_usage_scan": "扫描模型:",
"model_usage_add": "添加模型:",
"model_usage_verbose": "查看包含文件详情:",
"model_no_storage_paths": "未配置存储路径。",
"model_add_path_hint": "添加存储路径: kt config set model.storage_paths /path/to/models",
"model_scanning_paths": "正在扫描配置的存储路径...",
"model_scanning_progress": "扫描中: {path}",
"model_scan_warnings_title": "警告",
"model_scan_no_models_found": "在配置的路径中未找到模型。",
"model_scan_check_paths_hint": "检查存储路径: kt config get model.storage_paths",
"model_scan_min_size_hint": "文件夹必须 ≥{size}GB 才能被识别为模型。",
"model_scan_found_title": "发现 {count} 个新模型",
"model_column_path": "路径",
"model_column_size": "大小",
"model_scan_auto_adding": "正在自动添加模型...",
"model_added": "已添加: {name}",
"model_add_failed": "添加 {name} 失败: {error}",
"model_scan_complete": "扫描完成!已添加 {count} 个模型。",
"model_scan_interactive_prompt": "命令: edit <id> | del <id> | done",
"model_scan_cmd_edit": "设置模型自定义名称和仓库",
"model_scan_cmd_delete": "跳过此模型",
"model_scan_cmd_done": "完成并添加模型",
"model_scan_marked_skip": "已跳过模型 #{id}",
"model_scan_invalid_id": "无效的模型 ID: {id}",
"model_scan_invalid_command": "无效命令。使用: edit <id> | del <id> | done",
"model_scan_edit_model": "编辑模型 {id}",
"model_scan_edit_note": "您可以在添加到注册表前更改模型名称和配置仓库信息",
"model_scan_adding_models": "正在添加 {count} 个模型...",
"model_scan_next_steps": "后续步骤",
"model_scan_view_hint": "查看已注册模型: kt model list",
"model_scan_edit_hint": "编辑模型详情: kt model edit <name>",
"model_scan_no_models_added": "未添加任何模型。",
"model_add_path_not_exist": "错误: 路径不存在: {path}",
"model_add_not_directory": "错误: 路径不是目录: {path}",
"model_add_already_registered": "此路径已注册为: {name}",
"model_add_view_hint": "查看: kt model info {name}",
"model_add_scanning": "正在扫描模型文件...",
"model_add_scan_failed": "扫描模型失败: {error}",
"model_add_no_model_files": "{path} 中未找到模型文件",
"model_add_supported_formats": "支持: *.safetensors, *.gguf (文件夹 ≥10GB)",
"model_add_detected": "检测到: {format} 格式, {size}, {count} 个文件",
"model_add_name_conflict": "名称 '{name}' 已存在。",
"model_add_prompt_name": "为此模型输入名称",
"model_add_name_exists": "名称已存在。请选择其他名称:",
"model_add_configure_repo": "配置仓库信息以进行 SHA256 验证?",
"model_add_repo_type_prompt": "选择仓库类型:",
"model_add_choice": "选择",
"model_add_repo_id_prompt": "输入仓库 ID (例如: deepseek-ai/DeepSeek-V3)",
"model_add_success": "成功添加模型: {name}",
"model_add_verify_hint": "验证完整性: kt model verify {name}",
"model_add_edit_later_hint": "稍后编辑详情: kt model edit {name}",
"model_add_failed_generic": "添加模型失败: {error}",
"model_edit_not_found": "未找到模型 '{name}'",
"model_edit_list_hint": "列出模型: kt model list",
"model_edit_current_config": "当前配置",
"model_edit_what_to_edit": "您想编辑什么?",
"model_edit_option_name": "编辑名称",
"model_edit_option_repo": "配置仓库信息",
"model_edit_option_delete": "删除此模型",
"model_edit_option_cancel": "取消 / 退出",
"model_edit_choice_prompt": "选择选项",
"model_edit_new_name": "输入新名称",
"model_edit_name_conflict": "名称 '{name}' 已存在。请选择其他名称:",
"model_edit_name_updated": "名称已更新: {old}{new}",
"model_edit_repo_type_prompt": "仓库类型 (或按回车删除仓库信息):",
"model_edit_repo_remove": "删除仓库信息",
"model_edit_repo_id_prompt": "输入仓库 ID",
"model_edit_repo_removed": "仓库信息已删除",
"model_edit_repo_updated": "仓库已配置: {repo_type}{repo_id}",
"model_edit_delete_warning": "从注册表中删除模型 '{name}'?",
"model_edit_delete_note": "注意: 这只会删除注册表条目。{path} 中的模型文件不会被删除。",
"model_edit_delete_confirm": "确认删除?",
"model_edit_deleted": "模型 '{name}' 已从注册表中删除",
"model_edit_delete_cancelled": "删除已取消",
"model_edit_cancelled": "编辑已取消",
# Model edit - Interactive selection
"model_edit_select_title": "选择要编辑的模型",
"model_edit_select_model": "选择模型",
"model_edit_invalid_choice": "无效选择",
"model_edit_no_models": "注册表中未找到模型。",
"model_edit_add_hint_scan": "添加模型:",
"model_edit_add_hint_add": "或:",
# Model edit - Display
"model_edit_gpu_links": "GPU 链接:",
# Model edit - Menu options
"model_edit_manage_gpu_links": "管理 GPU 链接",
"model_edit_save_changes": "保存更改",
"model_edit_has_changes": "(有更改)",
"model_edit_no_changes": "(无更改)",
# Model edit - Pending changes messages
"model_edit_name_pending": "名称将在保存更改时更新。",
"model_edit_repo_remove_pending": "仓库信息将在保存更改时删除。",
"model_edit_repo_update_pending": "仓库信息将在保存更改时更新。",
# Model edit - GPU link management
"model_edit_gpu_links_title": "管理 {name} 的 GPU 链接",
"model_edit_current_gpu_links": "当前 GPU 链接:",
"model_edit_no_gpu_links": "未配置 GPU 链接。",
"model_edit_gpu_options": "选项:",
"model_edit_gpu_add": "添加 GPU 链接",
"model_edit_gpu_remove": "删除 GPU 链接",
"model_edit_gpu_clear": "清除所有 GPU 链接",
"model_edit_gpu_back": "返回主菜单",
"model_edit_gpu_choose_option": "选择选项",
"model_edit_gpu_none_available": "没有可链接的 GPU 模型。",
"model_edit_gpu_available_models": "可用的 GPU 模型:",
"model_edit_gpu_already_linked": "(已链接)",
"model_edit_gpu_enter_number": "输入要添加的 GPU 模型编号",
"model_edit_gpu_link_pending": "GPU 链接将在保存更改时添加: {name}",
"model_edit_gpu_already_exists": "此 GPU 模型已链接。",
"model_edit_gpu_invalid_choice": "无效选择。",
"model_edit_gpu_invalid_input": "无效输入。",
"model_edit_gpu_none_to_remove": "没有可删除的 GPU 链接。",
"model_edit_gpu_choose_to_remove": "选择要删除的 GPU 链接:",
"model_edit_gpu_enter_to_remove": "输入要删除的编号",
"model_edit_gpu_remove_pending": "GPU 链接将在保存更改时删除: {name}",
"model_edit_gpu_none_to_clear": "没有可清除的 GPU 链接。",
"model_edit_gpu_clear_confirm": "删除所有 GPU 链接?",
"model_edit_gpu_clear_pending": "所有 GPU 链接将在保存更改时删除。",
"model_edit_cancelled_short": "已取消。",
# Model edit - Save operation
"model_edit_no_changes_to_save": "没有更改可保存。",
"model_edit_saving": "正在保存更改...",
"model_edit_saved": "更改保存成功!",
"model_edit_updated_config": "更新后的配置:",
"model_edit_repo_changed_warning": "⚠ 仓库信息已更改。",
"model_edit_verify_hint": "运行 [cyan]kt model verify[/cyan] 以使用 SHA256 校验和验证模型完整性。",
"model_edit_discard_changes": "放弃未保存的更改?",
"model_info_not_found": "未找到模型 '{name}'",
"model_info_list_hint": "列出所有模型: kt model list",
"model_remove_not_found": "未找到模型 '{name}'",
"model_remove_list_hint": "列出模型: kt model list",
"model_remove_warning": "从注册表中删除模型 '{name}'?",
"model_remove_note": "注意: 这只会删除注册表条目。模型文件不会从 {path} 中删除。",
"model_remove_confirm": "确认删除?",
"model_remove_cancelled": "删除已取消",
"model_removed": "模型 '{name}' 已从注册表中删除",
"model_remove_failed": "删除模型失败: {error}",
"model_refresh_checking": "正在检查模型路径...",
"model_refresh_all_valid": "所有模型都有效! (已检查 {count} 个模型)",
"model_refresh_total": "总模型数: {total}",
"model_refresh_missing_found": "发现 {count} 个缺失的模型",
"model_refresh_suggestions": "建议操作",
"model_refresh_remove_hint": "从注册表中删除: kt model remove <name>",
"model_refresh_rescan_hint": "重新扫描模型: kt model scan",
"model_verify_not_found": "未找到模型 '{name}'",
"model_verify_list_hint": "列出模型: kt model list",
"model_verify_no_repo": "模型 '{name}' 未配置仓库信息。",
"model_verify_config_hint": "配置仓库: kt model edit {name}",
"model_verify_path_missing": "模型路径不存在: {path}",
"model_verify_starting": "正在验证模型完整性...",
"model_verify_progress": "仓库: {repo_type}{repo_id}",
"model_verify_not_implemented": "SHA256 验证尚未实现",
"model_verify_future_note": "此功能将从 {repo_type} 获取官方 SHA256 哈希值并与本地文件进行比较。",
"model_verify_passed": "验证通过!所有文件都与官方哈希匹配。",
"model_verify_failed": "验证失败!{count} 个文件的哈希不匹配。",
"model_verify_all_no_repos": "没有模型配置了仓库信息。",
"model_verify_all_config_hint": "配置仓库使用: kt model edit <name>",
"model_verify_all_found": "发现 {count} 个配置了仓库信息的模型",
"model_verify_all_manual_hint": "验证特定模型: kt model verify <name>",
# Coming soon
"feature_coming_soon": "此功能即将推出...",
},
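The head of this PR mentions universal input validators with Chinese comma support, which the strings above (`run_int_select_gpus`, `run_int_invalid_gpu_range`, `run_int_tp_not_power_of_2`) are the user-facing side of. A hedged sketch of that validation logic — helper names are illustrative, not taken from the codebase:

```python
def normalize_separators(raw: str) -> str:
    """Map full-width Chinese separators to ASCII commas so '0,1' and '0,1' parse alike."""
    return raw.replace(",", ",").replace("、", ",")

def parse_gpu_ids(raw: str, max_id: int) -> list[int]:
    """Parse a comma-separated GPU ID list; reject IDs outside 0..max_id."""
    ids = [int(tok) for tok in normalize_separators(raw).split(",") if tok.strip()]
    if any(not 0 <= i <= max_id for i in ids):
        raise ValueError(f"All GPU IDs must be between 0 and {max_id}")
    return ids

def is_power_of_two(n: int) -> bool:
    """TP size must be a power of 2 (1, 2, 4, 8, ...)."""
    return n > 0 and (n & (n - 1)) == 0
```

On a `ValueError`, the interactive flow can print the matching i18n message and re-prompt, which is presumably what the "auto-retry" in the head refers to.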

@@ -5,6 +5,10 @@ KTransformers CLI - A unified command-line interface for KTransformers.
"""
import sys
import warnings
# Suppress numpy subnormal warnings
warnings.filterwarnings("ignore", message="The value of the smallest subnormal")
import typer
@@ -28,6 +32,7 @@ def _get_help(key: str) -> str:
"run": {"en": "Start model inference server", "zh": "启动模型推理服务器"},
"chat": {"en": "Interactive chat with running model", "zh": "与运行中的模型进行交互式聊天"},
"quant": {"en": "Quantize model weights", "zh": "量化模型权重"},
"edit": {"en": "Edit model information", "zh": "编辑模型信息"},
"bench": {"en": "Run full benchmark", "zh": "运行完整基准测试"},
"microbench": {"en": "Run micro-benchmark", "zh": "运行微基准测试"},
"doctor": {"en": "Diagnose environment issues", "zh": "诊断环境问题"},
@@ -43,7 +48,7 @@ def _get_help(key: str) -> str:
app = typer.Typer(
name="kt",
help="KTransformers CLI - A unified command-line interface for KTransformers.",
-no_args_is_help=True,
+no_args_is_help=False,  # Handle no-args case manually to support first-run setup
add_completion=False, # Use static completion scripts instead of dynamic completion
rich_markup_mode="rich",
)
@@ -66,20 +71,7 @@ def _update_help_texts() -> None:
group_info.help = _get_help(group_info.name)
# Register commands
app.command(name="version", help="Show version information")(version.version)
# Run command is handled specially in main() to allow extra args
# (not registered here to avoid typer's argument parsing)
app.command(name="chat", help="Interactive chat with running model")(chat.chat)
app.command(name="quant", help="Quantize model weights")(quant.quant)
app.command(name="bench", help="Run full benchmark")(bench.bench)
app.command(name="microbench", help="Run micro-benchmark")(bench.microbench)
app.command(name="doctor", help="Diagnose environment issues")(doctor.doctor)
# Register sub-apps
app.add_typer(model.app, name="model", help="Manage models and storage paths")
app.add_typer(config.app, name="config", help="Manage configuration")
app.add_typer(sft.app, name="sft", help="Fine-tuning with LlamaFactory")
# Commands are registered later after tui_command is defined
def check_first_run() -> None:
@@ -116,7 +108,7 @@ def _show_first_run_setup(settings) -> None:
from rich.spinner import Spinner
from rich.live import Live
-from kt_kernel.cli.utils.environment import scan_storage_locations, format_size_gb, scan_models_in_location
+from kt_kernel.cli.utils.environment import scan_storage_locations, format_size_gb
console = Console()
@@ -140,15 +132,8 @@ def _show_first_run_setup(settings) -> None:
console.print(" [cyan][2][/cyan] 中文 (Chinese)")
console.print()
-while True:
-    choice = Prompt.ask("Enter choice / 输入选择", choices=["1", "2"], default="1")
-    if choice == "1":
-        lang = "en"
-        break
-    elif choice == "2":
-        lang = "zh"
-        break
+choice = Prompt.ask("Enter choice / 输入选择", choices=["1", "2"], default="1")
+lang = "en" if choice == "1" else "zh"
# Save language setting
settings.set("general.language", lang)
@@ -161,6 +146,131 @@ def _show_first_run_setup(settings) -> None:
else:
console.print("[green]✓[/green] Language set to English")
# Model discovery section
console.print()
if lang == "zh":
console.print("[bold]发现模型权重[/bold]")
console.print()
console.print("[dim]扫描系统中已有的模型权重文件,以便快速添加到模型列表。[/dim]")
console.print()
console.print(" [cyan][1][/cyan] 全局扫描 (自动扫描所有非系统路径)")
console.print(" [cyan][2][/cyan] 手动指定路径 (可添加多个)")
console.print(" [cyan][3][/cyan] 跳过 (稍后手动添加)")
console.print()
scan_choice = Prompt.ask("选择扫描方式", choices=["1", "2", "3"], default="1")
else:
console.print("[bold]Discover Model Weights[/bold]")
console.print()
console.print("[dim]Scan existing model weights on your system to quickly add them to the model list.[/dim]")
console.print()
console.print(" [cyan][1][/cyan] Global scan (auto-scan all non-system paths)")
console.print(" [cyan][2][/cyan] Manual paths (add multiple paths)")
console.print(" [cyan][3][/cyan] Skip (add manually later)")
console.print()
scan_choice = Prompt.ask("Select scan method", choices=["1", "2", "3"], default="1")
if scan_choice == "1":
# Global scan
from kt_kernel.cli.utils.model_discovery import discover_and_register_global, format_discovery_summary
console.print()
try:
total_found, new_found, registered = discover_and_register_global(
min_size_gb=2.0, max_depth=6, show_progress=True, lang=lang
)
format_discovery_summary(
total_found=total_found,
new_found=new_found,
registered=registered,
lang=lang,
show_models=True,
max_show=10,
)
except Exception as e:
console.print(f"[yellow]Warning: Scan failed - {e}[/yellow]")
elif scan_choice == "2":
# Manual path specification
from kt_kernel.cli.utils.model_discovery import discover_and_register_path
import os
discovered_paths = set() # Track paths discovered in this session
total_registered = []
while True:
console.print()
if lang == "zh":
path = Prompt.ask("输入要扫描的路径 (例如: /mnt/data/models)")
else:
path = Prompt.ask("Enter path to scan (e.g., /mnt/data/models)")
# Expand and validate path
path = os.path.expanduser(path)
if not os.path.exists(path):
if lang == "zh":
console.print(f"[yellow]警告: 路径不存在: {path}[/yellow]")
else:
console.print(f"[yellow]Warning: Path does not exist: {path}[/yellow]")
continue
if not os.path.isdir(path):
if lang == "zh":
console.print(f"[yellow]警告: 不是一个目录: {path}[/yellow]")
else:
console.print(f"[yellow]Warning: Not a directory: {path}[/yellow]")
continue
# Scan this path
console.print()
try:
total_found, new_found, registered = discover_and_register_path(
path=path, min_size_gb=2.0, existing_paths=discovered_paths, show_progress=True, lang=lang
)
# Update discovered paths
for model in registered:
discovered_paths.add(model.path)
total_registered.extend(registered)
console.print()
if lang == "zh":
console.print(f"[green]✓[/green] 在此路径找到 {total_found} 个模型,其中 {new_found} 个为新模型")
else:
console.print(f"[green]✓[/green] Found {total_found} models in this path, {new_found} are new")
if new_found > 0:
for model in registered[:5]:
console.print(f"{model.name} ({model.format})")
if len(registered) > 5:
if lang == "zh":
console.print(f" [dim]... 还有 {len(registered) - 5} 个新模型[/dim]")
else:
console.print(f" [dim]... and {len(registered) - 5} more new models[/dim]")
except Exception as e:
console.print(f"[red]Error scanning path: {e}[/red]")
# Ask if continue
console.print()
if lang == "zh":
continue_scan = Confirm.ask("是否继续添加其他路径?", default=False)
else:
continue_scan = Confirm.ask("Continue adding more paths?", default=False)
if not continue_scan:
break
if total_registered:
console.print()
if lang == "zh":
console.print(f"[green]✓[/green] 总共发现 {len(total_registered)} 个新模型")
else:
console.print(f"[green]✓[/green] Total {len(total_registered)} new models discovered")
# Model storage path selection
console.print()
console.print(f"[bold]{t('setup_model_path_title')}[/bold]")
@@ -174,16 +284,7 @@ def _show_first_run_setup(settings) -> None:
console.print()
if locations:
# Scan for models in each location
console.print(f"[dim]{t('setup_scanning_models')}[/dim]")
location_models: dict[str, list] = {}
for loc in locations[:5]:
models = scan_models_in_location(loc, max_depth=2)
if models:
location_models[loc.path] = models
console.print()
# Show storage location options
for i, loc in enumerate(locations[:5], 1): # Show top 5 options
available = format_size_gb(loc.available_gb)
total = format_size_gb(loc.total_gb)
@@ -194,22 +295,8 @@ def _show_first_run_setup(settings) -> None:
else:
option_str = t("setup_disk_option", path=loc.path, available=available, total=total)
# Add model count if any
if loc.path in location_models:
model_count = len(location_models[loc.path])
option_str += f" [green]✓ {t('setup_location_has_models', count=model_count)}[/green]"
console.print(f" [cyan][{i}][/cyan] {option_str}")
# Show first few models found in this location
if loc.path in location_models:
for model in location_models[loc.path][:3]: # Show up to 3 models
size_str = format_size_gb(model.size_gb)
console.print(f" [dim]• {model.name} ({size_str})[/dim]")
if len(location_models[loc.path]) > 3:
remaining = len(location_models[loc.path]) - 3
console.print(f" [dim] ... +{remaining} more[/dim]")
# Custom path option
custom_idx = min(len(locations), 5) + 1
console.print(f" [cyan][{custom_idx}][/cyan] {t('setup_custom_path')}")
@@ -323,51 +410,28 @@ def _install_shell_completion() -> None:
# Detect current shell
shell = os.environ.get("SHELL", "")
if "zsh" in shell:
shell_name = "zsh"
elif "fish" in shell:
shell_name = "fish"
else:
shell_name = "bash"
shell_name = "zsh" if "zsh" in shell else "fish" if "fish" in shell else "bash"
try:
cli_dir = Path(__file__).parent
completions_dir = cli_dir / "completions"
home = Path.home()
installed = False
def install_completion(src_name: str, dest_dir: Path, dest_name: str) -> None:
"""Install completion file from source to destination."""
src_file = completions_dir / src_name
if src_file.exists():
dest_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_file, dest_dir / dest_name)
if shell_name == "bash":
# Use XDG standard location for bash-completion (auto-loaded)
src_file = completions_dir / "kt-completion.bash"
dest_dir = home / ".local" / "share" / "bash-completion" / "completions"
dest_file = dest_dir / "kt"
if src_file.exists():
dest_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_file, dest_file)
installed = True
install_completion(
"kt-completion.bash", home / ".local" / "share" / "bash-completion" / "completions", "kt"
)
elif shell_name == "zsh":
src_file = completions_dir / "_kt"
dest_dir = home / ".zfunc"
dest_file = dest_dir / "_kt"
if src_file.exists():
dest_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_file, dest_file)
installed = True
install_completion("_kt", home / ".zfunc", "_kt")
elif shell_name == "fish":
# Fish auto-loads from this directory
src_file = completions_dir / "kt.fish"
dest_dir = home / ".config" / "fish" / "completions"
dest_file = dest_dir / "kt.fish"
if src_file.exists():
dest_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_file, dest_file)
installed = True
install_completion("kt.fish", home / ".config" / "fish" / "completions", "kt.fish")
# Mark as installed
settings.set("general._completion_installed", True)
@@ -403,6 +467,20 @@ def _apply_saved_language() -> None:
set_lang(lang)
app.command(name="version", help="Show version information")(version.version)
app.command(name="chat", help="Interactive chat with running model")(chat.chat)
app.command(name="quant", help="Quantize model weights")(quant.quant)
app.command(name="edit", help="Edit model information")(model.edit_model)
app.command(name="bench", help="Run full benchmark")(bench.bench)
app.command(name="microbench", help="Run micro-benchmark")(bench.microbench)
app.command(name="doctor", help="Diagnose environment issues")(doctor.doctor)
# Register sub-apps
app.add_typer(model.app, name="model", help="Manage models and storage paths")
app.add_typer(config.app, name="config", help="Manage configuration")
app.add_typer(sft.app, name="sft", help="Fine-tuning with LlamaFactory")
def main():
"""Main entry point."""
# Apply saved language setting first (before anything else for correct help display)
@@ -414,7 +492,7 @@ def main():
# Check for first run (but not for certain commands)
# Skip first-run check for: --help, config commands, version
args = sys.argv[1:] if len(sys.argv) > 1 else []
skip_commands = ["--help", "-h", "config", "version", "--version"]
skip_commands = ["--help", "-h", "config", "version", "--version", "--no-tui"]
should_check_first_run = True
for arg in args:
@@ -422,12 +500,35 @@ def main():
should_check_first_run = False
break
# Handle no arguments case
if not args:
# Check if this is first run
from kt_kernel.cli.config.settings import DEFAULT_CONFIG_FILE, get_settings
is_first_run = False
if not DEFAULT_CONFIG_FILE.exists():
is_first_run = True
else:
settings = get_settings()
if not settings.get("general._initialized"):
is_first_run = True
if is_first_run:
# First run - start initialization
_install_shell_completion()
check_first_run()
return
else:
# Not first run - show help
app(["--help"])
return
# Auto-install shell completion on first run
if should_check_first_run:
_install_shell_completion()
# Check first run before running commands
if should_check_first_run:
check_first_run()
# Handle "run" command specially to pass through unknown options


@@ -0,0 +1,413 @@
#!/usr/bin/env python3
"""
快速分析 MoE 模型 - 基于 config.json
(复用 sglang 的模型注册表和判断逻辑)
"""
import json
import hashlib
from pathlib import Path
from typing import Optional, Dict, Any
def _get_sglang_moe_architectures():
"""
从 sglang 的模型注册表获取所有 MoE 架构
复用 sglang 的代码,这样 sglang 更新后自动支持新模型
"""
try:
import sys
# 添加 sglang 路径到 sys.path
sglang_path = Path("/mnt/data2/ljq/sglang/python")
if sglang_path.exists() and str(sglang_path) not in sys.path:
sys.path.insert(0, str(sglang_path))
# 直接导入 sglang 的 ModelRegistry
# 注意:这需要 sglang 及其依赖正确安装
from sglang.srt.models.registry import ModelRegistry
# 获取所有支持的架构
supported_archs = ModelRegistry.get_supported_archs()
# 过滤出 MoE 模型(名称包含 Moe
moe_archs = {arch for arch in supported_archs if "Moe" in arch or "moe" in arch.lower()}
# 手动添加一些不带 "Moe" 字样但是 MoE 模型的架构
# DeepSeek V2/V3 系列
deepseek_moe = {arch for arch in supported_archs if arch.startswith("Deepseek") or arch.startswith("deepseek")}
moe_archs.update(deepseek_moe)
# DBRX 也是 MoE 模型
dbrx_moe = {arch for arch in supported_archs if "DBRX" in arch or "dbrx" in arch.lower()}
moe_archs.update(dbrx_moe)
# Grok 也是 MoE 模型
grok_moe = {arch for arch in supported_archs if "Grok" in arch or "grok" in arch.lower()}
moe_archs.update(grok_moe)
return moe_archs
except Exception as e:
# 如果 sglang 不可用,返回空集合
# 这种情况下,后续会使用配置文件中的其他判断方法
import warnings
warnings.warn(f"Failed to load MoE architectures from sglang: {e}. Using fallback detection methods.")
return set()
# 获取 MoE 架构列表(优先从 sglang 获取)
MOE_ARCHITECTURES = _get_sglang_moe_architectures()
def _get_cache_file():
"""获取集中式缓存文件路径"""
cache_dir = Path.home() / ".ktransformers" / "cache"
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir / "moe_analysis_v2.json"
def _load_all_cache():
"""加载所有缓存数据"""
cache_file = _get_cache_file()
if not cache_file.exists():
return {}
try:
with open(cache_file, "r") as f:
return json.load(f)
except Exception:
return {}
def _save_all_cache(cache_data):
"""保存所有缓存数据"""
cache_file = _get_cache_file()
try:
with open(cache_file, "w") as f:
json.dump(cache_data, f, indent=2)
except Exception as e:
import warnings
warnings.warn(f"Failed to save MoE cache: {e}")
def _compute_config_fingerprint(config_path: Path) -> Optional[str]:
"""计算 config.json 指纹"""
if not config_path.exists():
return None
try:
stat = config_path.stat()
# Use file size and modification time as the fingerprint
fingerprint_str = f"{config_path.name}:{stat.st_size}:{int(stat.st_mtime)}"
return hashlib.md5(fingerprint_str.encode()).hexdigest()
except Exception:
return None
def _load_cache(model_path: Path) -> Optional[Dict[str, Any]]:
"""加载指定模型的缓存"""
model_path_str = str(model_path.resolve())
all_cache = _load_all_cache()
if model_path_str not in all_cache:
return None
try:
cache_entry = all_cache[model_path_str]
# Validate the cache version
cache_version = cache_entry.get("cache_version", 0)
if cache_version != 2:
return None
# Validate the config.json fingerprint
config_path = model_path / "config.json"
current_fingerprint = _compute_config_fingerprint(config_path)
if cache_entry.get("fingerprint") != current_fingerprint:
return None
return cache_entry.get("result")
except Exception:
return None
def _save_cache(model_path: Path, result: Dict[str, Any]):
"""保存指定模型的缓存"""
model_path_str = str(model_path.resolve())
try:
config_path = model_path / "config.json"
fingerprint = _compute_config_fingerprint(config_path)
all_cache = _load_all_cache()
all_cache[model_path_str] = {
"fingerprint": fingerprint,
"result": result,
"cache_version": 2,
"last_updated": __import__("datetime").datetime.now().isoformat(),
}
_save_all_cache(all_cache)
except Exception as e:
import warnings
warnings.warn(f"Failed to save MoE cache for {model_path}: {e}")
def _load_config_json(model_path: Path) -> Optional[Dict[str, Any]]:
"""读取 config.json 文件
参考 sglang 的 get_config() 实现
"""
config_path = model_path / "config.json"
if not config_path.exists():
return None
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
return config
except Exception:
return None
def _is_moe_model(config: Dict[str, Any]) -> bool:
"""判断是否是 MoE 模型
参考 sglang 的模型注册表和架构识别方式
"""
# Method 1: check the architecture names
architectures = config.get("architectures", [])
if any(arch in MOE_ARCHITECTURES for arch in architectures):
return True
# Method 2: check for a MoE-specific section (Mistral format)
if config.get("moe"):
return True
# Method 3: check for num_experts or one of its variant fields
# (some multimodal models nest them under text_config)
text_config = config.get("text_config", config)
# Check the various expert-count fields
if (
text_config.get("num_experts") or text_config.get("num_local_experts") or text_config.get("n_routed_experts")
):  # Kimi-K2 uses n_routed_experts
return True
return False
def _extract_moe_params(config: Dict[str, Any]) -> Dict[str, Any]:
"""从 config 中提取 MoE 参数
参考 sglang 的各种 MoE 模型实现
"""
# 处理嵌套的 text_config
text_config = config.get("text_config", config)
# 提取基本参数
result = {
"architectures": config.get("architectures", []),
"model_type": config.get("model_type", "unknown"),
}
# Total expert count (the field name varies across models)
num_experts = (
text_config.get("num_experts") # Qwen2/3 MoE, DeepSeek V2
or text_config.get("num_local_experts") # Mixtral
or text_config.get("n_routed_experts") # Kimi-K2, DeepSeek V3
or config.get("moe", {}).get("num_experts") # Mistral 格式
)
# Number of experts activated per token
num_experts_per_tok = (
text_config.get("num_experts_per_tok")
or text_config.get("num_experts_per_token")
or config.get("moe", {}).get("num_experts_per_tok")
or 2  # default value
)
# Number of layers
num_hidden_layers = text_config.get("num_hidden_layers") or text_config.get("n_layer") or 0
# Hidden size
hidden_size = text_config.get("hidden_size") or text_config.get("d_model") or 0
# MoE expert intermediate size
moe_intermediate_size = (
text_config.get("moe_intermediate_size")
or text_config.get("intermediate_size") # 如果没有特殊的 moe_intermediate_size
or 0
)
# Shared-expert intermediate size (Qwen2/3 MoE)
shared_expert_intermediate_size = text_config.get("shared_expert_intermediate_size", 0)
result.update(
{
"num_experts": num_experts or 0,
"num_experts_per_tok": num_experts_per_tok,
"num_hidden_layers": num_hidden_layers,
"hidden_size": hidden_size,
"moe_intermediate_size": moe_intermediate_size,
"shared_expert_intermediate_size": shared_expert_intermediate_size,
}
)
# Extract other useful parameters
result["num_attention_heads"] = text_config.get("num_attention_heads", 0)
result["num_key_value_heads"] = text_config.get("num_key_value_heads", 0)
result["vocab_size"] = text_config.get("vocab_size", 0)
result["max_position_embeddings"] = text_config.get("max_position_embeddings", 0)
return result
def _estimate_model_size(model_path: Path) -> float:
"""估算模型总大小GB
快速统计 safetensors 文件总大小
"""
try:
total_size = 0
for file_path in model_path.glob("*.safetensors"):
total_size += file_path.stat().st_size
return total_size / (1024**3)
except Exception:
return 0.0
def analyze_moe_model(model_path, use_cache=True):
"""
Quick MoE model analysis that only reads config.json.
Args:
model_path: model path (str or Path)
use_cache: whether to use the cache (default True)
Returns:
dict: {
'is_moe': whether this is a MoE model,
'num_experts': total number of experts,
'num_experts_per_tok': experts activated per token,
'num_hidden_layers': number of layers,
'hidden_size': hidden dimension,
'moe_intermediate_size': MoE expert intermediate size,
'shared_expert_intermediate_size': shared-expert intermediate size,
'architectures': model architecture list,
'model_type': model type,
'total_size_gb': estimated total model size (GB),
'cached': whether the result came from cache
}
Returns None if the model is not MoE or analysis fails.
"""
model_path = Path(model_path)
if not model_path.exists():
return None
# Try the cache first
if use_cache:
cached_result = _load_cache(model_path)
if cached_result:
cached_result["cached"] = True
return cached_result
# Read config.json
config = _load_config_json(model_path)
if not config:
return None
# Check whether this is a MoE model
if not _is_moe_model(config):
return None
# Extract the MoE parameters
params = _extract_moe_params(config)
# Validate required parameters
if params["num_experts"] == 0:
return None
# Estimate the model size
total_size_gb = _estimate_model_size(model_path)
# Assemble the result
result = {
"is_moe": True,
"num_experts": params["num_experts"],
"num_experts_per_tok": params["num_experts_per_tok"],
"num_hidden_layers": params["num_hidden_layers"],
"hidden_size": params["hidden_size"],
"moe_intermediate_size": params["moe_intermediate_size"],
"shared_expert_intermediate_size": params["shared_expert_intermediate_size"],
"architectures": params["architectures"],
"model_type": params["model_type"],
"total_size_gb": total_size_gb,
"cached": False,
# Extra parameters
"num_attention_heads": params.get("num_attention_heads", 0),
"num_key_value_heads": params.get("num_key_value_heads", 0),
"vocab_size": params.get("vocab_size", 0),
}
# Save to cache
if use_cache:
_save_cache(model_path, result)
return result
def print_analysis(model_path):
"""打印模型分析结果"""
print(f"分析模型: {model_path}\n")
result = analyze_moe_model(model_path)
if result is None:
print("不是 MoE 模型或分析失败")
return
print("=" * 70)
print("MoE 模型分析结果")
if result.get("cached"):
print("[使用缓存]")
print("=" * 70)
print(f"模型架构:")
print(f" - 架构: {', '.join(result['architectures'])}")
print(f" - 类型: {result['model_type']}")
print()
print(f"MoE 结构:")
print(f" - 专家总数: {result['num_experts']}")
print(f" - 激活专家数: {result['num_experts_per_tok']} experts/token")
print(f" - 层数: {result['num_hidden_layers']}")
print(f" - 隐藏维度: {result['hidden_size']}")
print(f" - MoE 中间层: {result['moe_intermediate_size']}")
if result["shared_expert_intermediate_size"] > 0:
print(f" - 共享专家中间层: {result['shared_expert_intermediate_size']}")
print()
print(f"大小统计:")
print(f" - 模型总大小: {result['total_size_gb']:.2f} GB")
print("=" * 70)
print()
def main():
import sys
models = ["/mnt/data2/models/Qwen3-30B-A3B", "/mnt/data2/models/Qwen3-235B-A22B-Instruct-2507"]
if len(sys.argv) > 1:
models = [sys.argv[1]]
for model_path in models:
print_analysis(model_path)
if __name__ == "__main__":
main()
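The three detection methods in `_is_moe_model` above condense into a standalone predicate; a sketch with the architecture set stubbed to a few illustrative names (the real script fills it from sglang's ModelRegistry when available):

```python
from typing import Any, Dict

# Illustrative stub of MOE_ARCHITECTURES; the script derives this from sglang.
MOE_ARCHITECTURES = {"Qwen3MoeForCausalLM", "MixtralForCausalLM", "DeepseekV3ForCausalLM"}

def is_moe_config(config: Dict[str, Any]) -> bool:
    # Method 1: known MoE architecture name
    if any(a in MOE_ARCHITECTURES for a in config.get("architectures", [])):
        return True
    # Method 2: Mistral-style nested "moe" section
    if config.get("moe"):
        return True
    # Method 3: expert-count fields, possibly nested under text_config
    text_config = config.get("text_config", config)
    return bool(
        text_config.get("num_experts")
        or text_config.get("num_local_experts")
        or text_config.get("n_routed_experts")
    )

print(is_moe_config({"architectures": ["MixtralForCausalLM"]}))   # True
print(is_moe_config({"text_config": {"n_routed_experts": 256}}))  # True
print(is_moe_config({"architectures": ["LlamaForCausalLM"]}))     # False
```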


@@ -0,0 +1,118 @@
"""
Debug utility to inspect saved run configurations.
Usage: python -m kt_kernel.cli.utils.debug_configs
"""
from pathlib import Path
import yaml
from rich.console import Console
from rich.table import Table
from rich import box
console = Console()
def main():
"""Show all saved configurations."""
config_file = Path.home() / ".ktransformers" / "run_configs.yaml"
console.print()
console.print(f"[bold]Configuration file:[/bold] {config_file}")
console.print()
if not config_file.exists():
console.print("[red]✗ Configuration file does not exist![/red]")
console.print()
console.print("No configurations have been saved yet.")
return
try:
with open(config_file, "r", encoding="utf-8") as f:
data = yaml.safe_load(f) or {}
except Exception as e:
console.print(f"[red]✗ Failed to load configuration file: {e}[/red]")
return
console.print(f"[green]✓[/green] Configuration file loaded")
console.print()
configs = data.get("configs", {})
if not configs:
console.print("[yellow]No saved configurations found.[/yellow]")
return
console.print(f"[bold]Found configurations for {len(configs)} model(s):[/bold]")
console.print()
for model_id, model_configs in configs.items():
console.print(f"[cyan]Model ID:[/cyan] {model_id}")
console.print(f"[dim] {len(model_configs)} configuration(s)[/dim]")
console.print()
if not model_configs:
continue
# Display configs in a table
table = Table(box=box.ROUNDED, show_header=True, header_style="bold cyan")
table.add_column("#", justify="right", style="cyan")
table.add_column("Name", style="white")
table.add_column("Method", style="yellow")
table.add_column("TP", justify="right", style="green")
table.add_column("GPU Experts", justify="right", style="magenta")
table.add_column("Created", style="dim")
for i, cfg in enumerate(model_configs, 1):
method = cfg.get("inference_method", "?")
kt_method = cfg.get("kt_method", "?")
method_display = f"{method.upper()}"
if method == "raw":
method_display += f" ({cfg.get('raw_method', '?')})"
elif method == "amx":
method_display += f" ({kt_method})"
table.add_row(
str(i),
cfg.get("config_name", f"Config {i}"),
method_display,
str(cfg.get("tp_size", "?")),
str(cfg.get("gpu_experts", "?")),
cfg.get("created_at", "Unknown")[:19] if cfg.get("created_at") else "Unknown",
)
console.print(table)
console.print()
# Also check user_models.yaml to show model names
console.print("[bold]Checking model registry...[/bold]")
console.print()
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
try:
registry = UserModelRegistry()
all_models = registry.list_models()
console.print(f"[green]✓[/green] Found {len(all_models)} registered model(s)")
console.print()
# Map model IDs to names
id_to_name = {m.id: m.name for m in all_models}
console.print("[bold]Model ID → Name mapping:[/bold]")
console.print()
for model_id in configs.keys():
model_name = id_to_name.get(model_id, "[red]Unknown (model not found in registry)[/red]")
console.print(f" {model_id[:8]}... → {model_name}")
console.print()
except Exception as e:
console.print(f"[yellow]⚠ Could not load model registry: {e}[/yellow]")
console.print()
if __name__ == "__main__":
main()
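For reference, a sketch of the data layout this debug script expects in `~/.ktransformers/run_configs.yaml`, modeled with plain dicts: a top-level "configs" mapping from model ID to a list of saved configuration dicts. All values below are illustrative, not taken from a real install:

```python
# Hypothetical example of the parsed YAML structure and how the table rows derive.
data = {
    "configs": {
        "a1b2c3d4e5f6": [
            {
                "config_name": "default",
                "inference_method": "amx",
                "kt_method": "int8",
                "tp_size": 2,
                "gpu_experts": 8,
                "created_at": "2025-01-01T12:00:00.000000",
            }
        ]
    }
}

for model_id, model_configs in data["configs"].items():
    for i, cfg in enumerate(model_configs, 1):
        method = cfg.get("inference_method", "?")
        # AMX configs display their kt_method variant alongside the method name
        label = method.upper() + (f" ({cfg.get('kt_method', '?')})" if method == "amx" else "")
        print(model_id[:8], cfg["config_name"], label, cfg["created_at"][:19])
```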


@@ -0,0 +1,146 @@
"""Helper functions for interactive model download."""
from pathlib import Path
from typing import Any, Dict, List, Tuple
import fnmatch
def list_remote_files_hf(repo_id: str, use_mirror: bool = False) -> List[Dict[str, Any]]:
"""
List files in a HuggingFace repository.
Returns:
List of dicts with keys: 'path', 'size' (in bytes)
"""
from huggingface_hub import HfApi
import os
# Set mirror if needed
original_endpoint = os.environ.get("HF_ENDPOINT")
if use_mirror and not original_endpoint:
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
try:
api = HfApi()
files_info = api.list_repo_tree(repo_id=repo_id, recursive=True)
result = []
for item in files_info:
# Skip directories
if hasattr(item, "type") and item.type == "directory":
continue
# Get file info
file_path = item.path if hasattr(item, "path") else str(item)
file_size = item.size if hasattr(item, "size") else 0
result.append({"path": file_path, "size": file_size})
return result
finally:
# Restore original endpoint
if use_mirror and not original_endpoint:
os.environ.pop("HF_ENDPOINT", None)
elif original_endpoint:
os.environ["HF_ENDPOINT"] = original_endpoint
def list_remote_files_ms(repo_id: str) -> List[Dict[str, Any]]:
"""
List files in a ModelScope repository.
Returns:
List of dicts with keys: 'path', 'size' (in bytes)
"""
from modelscope.hub.api import HubApi
api = HubApi()
files_info = api.get_model_files(model_id=repo_id, recursive=True)
result = []
for file_info in files_info:
file_path = file_info.get("Name", file_info.get("Path", ""))
file_size = file_info.get("Size", 0)
result.append({"path": file_path, "size": file_size})
return result
def filter_files_by_pattern(files: List[Dict[str, Any]], pattern: str) -> List[Dict[str, Any]]:
"""Filter files by glob pattern."""
if pattern == "*":
return files
filtered = []
for file in files:
# Check if filename matches pattern
filename = Path(file["path"]).name
full_path = file["path"]
if fnmatch.fnmatch(filename, pattern) or fnmatch.fnmatch(full_path, pattern):
filtered.append(file)
return filtered
def calculate_total_size(files: List[Dict[str, Any]]) -> int:
"""Calculate total size of files in bytes."""
return sum(f["size"] for f in files)
def format_file_list_table(files: List[Dict[str, Any]], max_display: int = 10):
"""Format file list as a table for display."""
from rich.table import Table
from kt_kernel.cli.utils.model_scanner import format_size
table = Table(show_header=True, header_style="bold")
table.add_column("File", style="cyan", overflow="fold")
table.add_column("Size", justify="right")
# Show first max_display files
for file in files[:max_display]:
table.add_row(file["path"], format_size(file["size"]))
if len(files) > max_display:
table.add_row(f"... and {len(files) - max_display} more files", "[dim]...[/dim]")
return table
def verify_repo_exists(repo_id: str, repo_type: str, use_mirror: bool = False) -> Tuple[bool, str]:
"""
Verify if a repository exists.
Returns:
(exists: bool, message: str)
"""
try:
if repo_type == "huggingface":
import os
original_endpoint = os.environ.get("HF_ENDPOINT")
if use_mirror and not original_endpoint:
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
from huggingface_hub import HfApi
try:
api = HfApi()
api.repo_info(repo_id=repo_id, repo_type="model")
return True, "Repository found"
finally:
if use_mirror and not original_endpoint:
os.environ.pop("HF_ENDPOINT", None)
elif original_endpoint:
os.environ["HF_ENDPOINT"] = original_endpoint
else: # modelscope
from modelscope.hub.api import HubApi
api = HubApi()
api.get_model(model_id=repo_id)
return True, "Repository found"
except Exception as e:
return False, f"Repository not found: {str(e)}"


@@ -0,0 +1,216 @@
"""
Input validation utilities with retry mechanism.
Provides robust input validation with automatic retry on failure.
"""
from typing import Optional, List, Callable, Any
from rich.console import Console
from rich.prompt import Prompt
console = Console()
def prompt_int_with_retry(
message: str,
default: Optional[int] = None,
min_val: Optional[int] = None,
max_val: Optional[int] = None,
validator: Optional[Callable[[int], bool]] = None,
validator_error_msg: Optional[str] = None,
) -> int:
"""Prompt for integer input with validation and retry.
Args:
message: Prompt message
default: Default value (optional)
min_val: Minimum allowed value (optional)
max_val: Maximum allowed value (optional)
validator: Custom validation function (optional)
validator_error_msg: Error message for custom validator (optional)
Returns:
Validated integer value
"""
while True:
# Get input (Prompt.ask renders the default itself, so avoid duplicating it in the message)
user_input = Prompt.ask(message, default=str(default) if default is not None else None)
# Try to parse as integer
try:
value = int(user_input)
except ValueError:
console.print(f"[red]✗ Invalid input. Please enter a valid integer.[/red]")
console.print()
continue
# Validate range
if min_val is not None and value < min_val:
console.print(f"[red]✗ Value must be at least {min_val}[/red]")
console.print()
continue
if max_val is not None and value > max_val:
console.print(f"[red]✗ Value must be at most {max_val}[/red]")
console.print()
continue
# Custom validation
if validator is not None:
if not validator(value):
error_msg = validator_error_msg or "Invalid value"
console.print(f"[red]✗ {error_msg}[/red]")
console.print()
continue
# All validations passed
return value
def prompt_float_with_retry(
message: str,
default: Optional[float] = None,
min_val: Optional[float] = None,
max_val: Optional[float] = None,
) -> float:
"""Prompt for float input with validation and retry.
Args:
message: Prompt message
default: Default value (optional)
min_val: Minimum allowed value (optional)
max_val: Maximum allowed value (optional)
Returns:
Validated float value
"""
while True:
# Get input (Prompt.ask renders the default itself, so avoid duplicating it in the message)
user_input = Prompt.ask(message, default=str(default) if default is not None else None)
# Try to parse as float
try:
value = float(user_input)
except ValueError:
console.print(f"[red]✗ Invalid input. Please enter a valid number.[/red]")
console.print()
continue
# Validate range
if min_val is not None and value < min_val:
console.print(f"[red]✗ Value must be at least {min_val}[/red]")
console.print()
continue
if max_val is not None and value > max_val:
console.print(f"[red]✗ Value must be at most {max_val}[/red]")
console.print()
continue
# All validations passed
return value
def prompt_choice_with_retry(
message: str,
choices: List[str],
default: Optional[str] = None,
) -> str:
"""Prompt for choice input with validation and retry.
Args:
message: Prompt message
choices: List of valid choices
default: Default choice (optional)
Returns:
Selected choice
"""
while True:
# Get input
user_input = Prompt.ask(message, default=default)
# Validate choice
if user_input not in choices:
console.print(f"[red]✗ Invalid choice. Please select from: {', '.join(choices)}[/red]")
console.print()
continue
return user_input
def prompt_int_list_with_retry(
message: str,
default: Optional[str] = None,
min_val: Optional[int] = None,
max_val: Optional[int] = None,
validator: Optional[Callable[[List[int]], tuple[bool, Optional[str]]]] = None,
) -> List[int]:
"""Prompt for comma-separated integer list with validation and retry.
Args:
message: Prompt message
default: Default value as string (e.g., "0,1,2,3")
min_val: Minimum allowed value for each integer (optional)
max_val: Maximum allowed value for each integer (optional)
validator: Custom validation function that returns (is_valid, error_message) (optional)
Returns:
List of validated integers
"""
while True:
# Get input
user_input = Prompt.ask(message, default=default)
# Clean input: support Chinese comma and spaces
user_input_cleaned = user_input.replace("，", ",").replace(" ", "")
# Try to parse as integers
try:
values = [int(x.strip()) for x in user_input_cleaned.split(",") if x.strip()]
except ValueError:
console.print(f"[red]✗ Invalid format. Please enter numbers separated by commas.[/red]")
console.print()
continue
# Validate each value's range
invalid_values = []
for value in values:
if min_val is not None and value < min_val:
invalid_values.append(value)
elif max_val is not None and value > max_val:
invalid_values.append(value)
if invalid_values:
if min_val is not None and max_val is not None:
console.print(f"[red]✗ Invalid value(s): {invalid_values}[/red]")
console.print(f"[yellow]Valid range: {min_val}-{max_val}[/yellow]")
elif min_val is not None:
console.print(f"[red]✗ Value(s) must be at least {min_val}: {invalid_values}[/red]")
elif max_val is not None:
console.print(f"[red]✗ Value(s) must be at most {max_val}: {invalid_values}[/red]")
console.print()
continue
# Custom validation
if validator is not None:
is_valid, error_msg = validator(values)
if not is_valid:
console.print(f"[red]✗ {error_msg}[/red]")
console.print()
continue
# All validations passed
return values
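The normalization step above (full-width comma and whitespace handling) is the part most worth testing in isolation; a pure-function sketch of that parsing, separate from the interactive retry loop:

```python
# Hedged sketch: normalize full-width (Chinese) commas and spaces, then parse.
def parse_int_list(raw: str) -> list[int]:
    cleaned = raw.replace("，", ",").replace(" ", "")
    return [int(x) for x in cleaned.split(",") if x]

print(parse_int_list("0，1, 2,3"))  # [0, 1, 2, 3]
print(parse_int_list("4"))          # [4]
```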


@@ -0,0 +1,207 @@
#!/usr/bin/env python3
"""
KV Cache Size Calculator for SGLang
This script calculates the KV cache size in GB for a given model and number of tokens.
It follows the same logic as in sglang/srt/model_executor/model_runner.py
"""
import os
import sys
import torch
from transformers import AutoConfig
# Add sglang to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "python"))
from sglang.srt.configs.model_config import ModelConfig, is_deepseek_nsa, get_nsa_index_head_dim
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
def get_dtype_bytes(dtype_str: str) -> int:
"""Get the number of bytes for a given dtype string."""
dtype_map = {
"float32": 4,
"float16": 2,
"bfloat16": 2,
"float8_e4m3fn": 1,
"float8_e5m2": 1,
"auto": 2, # Usually defaults to bfloat16
}
return dtype_map.get(dtype_str, 2)
def get_kv_size_gb(
model_path: str,
max_total_tokens: int,
tp: int = 1,
dtype: str = "auto",
verbose: bool = True,
) -> dict:
"""
Calculate the KV cache size in GB for a given model and number of tokens.
Args:
model_path: Path to the model
max_total_tokens: Maximum number of tokens to cache
tp: Tensor parallelism size
dtype: Data type for KV cache (auto, float16, bfloat16, float8_e4m3fn, etc.)
verbose: Whether to print detailed information
Returns:
dict: Dictionary containing calculation details
"""
# Load model config
model_config = ModelConfig(model_path, dtype=dtype)
hf_config = model_config.hf_config
# Determine dtype bytes
dtype_bytes = get_dtype_bytes(dtype)
if dtype == "auto":
# Auto dtype usually becomes bfloat16
dtype_bytes = 2
# Number of layers
num_layers = model_config.num_attention_layers
# Check if it's MLA (Multi-head Latent Attention) model
is_mla = hasattr(model_config, "attention_arch") and model_config.attention_arch.name == "MLA"
result = {
"model_path": model_path,
"max_total_tokens": max_total_tokens,
"tp": tp,
"dtype": dtype,
"dtype_bytes": dtype_bytes,
"num_layers": num_layers,
"is_mla": is_mla,
}
if is_mla:
# MLA models (DeepSeek-V2/V3, MiniCPM3, etc.)
kv_lora_rank = model_config.kv_lora_rank
qk_rope_head_dim = model_config.qk_rope_head_dim
# Calculate cell size (per token)
cell_size = (kv_lora_rank + qk_rope_head_dim) * num_layers * dtype_bytes
result.update(
{
"kv_lora_rank": kv_lora_rank,
"qk_rope_head_dim": qk_rope_head_dim,
"cell_size_bytes": cell_size,
}
)
# Check if it's NSA (Native Sparse Attention) model
if is_deepseek_nsa(hf_config):
index_head_dim = get_nsa_index_head_dim(hf_config)
indexer_size_per_token = index_head_dim + index_head_dim // NSATokenToKVPool.quant_block_size * 4
indexer_dtype_bytes = torch._utils._element_size(NSATokenToKVPool.index_k_with_scale_buffer_dtype)
indexer_cell_size = indexer_size_per_token * num_layers * indexer_dtype_bytes
cell_size += indexer_cell_size
result.update(
{
"is_nsa": True,
"index_head_dim": index_head_dim,
"indexer_cell_size_bytes": indexer_cell_size,
"total_cell_size_bytes": cell_size,
}
)
else:
result["is_nsa"] = False
else:
# Standard MHA models
num_kv_heads = model_config.get_num_kv_heads(tp)
head_dim = model_config.head_dim
v_head_dim = model_config.v_head_dim
# Calculate cell size (per token)
cell_size = num_kv_heads * (head_dim + v_head_dim) * num_layers * dtype_bytes
result.update(
{
"num_kv_heads": num_kv_heads,
"head_dim": head_dim,
"v_head_dim": v_head_dim,
"cell_size_bytes": cell_size,
}
)
# Calculate total KV cache size
total_size_bytes = max_total_tokens * cell_size
total_size_gb = total_size_bytes / (1024**3)
# For MHA models with separate K and V buffers
if not is_mla:
k_size_bytes = max_total_tokens * num_kv_heads * head_dim * num_layers * dtype_bytes
v_size_bytes = max_total_tokens * num_kv_heads * v_head_dim * num_layers * dtype_bytes
k_size_gb = k_size_bytes / (1024**3)
v_size_gb = v_size_bytes / (1024**3)
result.update(
{
"k_size_gb": k_size_gb,
"v_size_gb": v_size_gb,
}
)
result.update(
{
"total_size_bytes": total_size_bytes,
"total_size_gb": total_size_gb,
}
)
if verbose:
print(f"Model: {model_path}")
print(f"Tokens: {max_total_tokens}, TP: {tp}, Dtype: {dtype}")
print(f"Architecture: {'MLA' if is_mla else 'MHA'}")
print(f"Layers: {num_layers}")
if is_mla:
print(f"KV LoRA Rank: {kv_lora_rank}, QK RoPE Head Dim: {qk_rope_head_dim}")
if result.get("is_nsa"):
print(f"NSA Index Head Dim: {index_head_dim}")
print(
f"Cell size: {cell_size} bytes (Main: {result['cell_size_bytes']}, Indexer: {result['indexer_cell_size_bytes']})"
)
else:
print(f"Cell size: {cell_size} bytes")
else:
print(f"KV Heads: {num_kv_heads}, Head Dim: {head_dim}, V Head Dim: {v_head_dim}")
print(f"Cell size: {cell_size} bytes")
print(f"K size: {k_size_gb:.2f} GB, V size: {v_size_gb:.2f} GB")
print(f"Total KV Cache Size: {total_size_gb:.2f} GB")
return result
def main():
import argparse
parser = argparse.ArgumentParser(description="Calculate KV cache size for a model")
parser.add_argument("model_path", help="Path to the model")
parser.add_argument("max_total_tokens", type=int, help="Maximum number of tokens")
parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism size")
parser.add_argument("--dtype", type=str, default="auto", help="Data type (auto, float16, bfloat16, etc.)")
parser.add_argument("--quiet", action="store_true", help="Suppress verbose output")
args = parser.parse_args()
result = get_kv_size_gb(
args.model_path,
args.max_total_tokens,
tp=args.tp,
dtype=args.dtype,
verbose=not args.quiet,
)
if args.quiet:
print(f"{result['total_size_gb']:.2f}")
if __name__ == "__main__":
main()
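As a sanity check on the MLA branch above, the per-token cell size can be computed by hand. The numbers below are hypothetical DeepSeek-V3-style parameters (kv_lora_rank=512, qk_rope_head_dim=64, 61 layers, bf16 = 2 bytes); verify against the actual model's `config.json` before relying on them:

```python
# MLA cell size: (kv_lora_rank + qk_rope_head_dim) * num_layers * dtype_bytes
kv_lora_rank, qk_rope_head_dim = 512, 64   # hypothetical DeepSeek-V3-like values
num_layers, dtype_bytes = 61, 2            # bf16

cell_size = (kv_lora_rank + qk_rope_head_dim) * num_layers * dtype_bytes
assert cell_size == 70272  # bytes per token

# Total KV cache for a 128K-token budget, mirroring get_kv_size_gb's final step
max_total_tokens = 131072
total_gb = max_total_tokens * cell_size / (1024**3)
assert round(total_gb, 2) == 8.58
```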


@@ -0,0 +1,250 @@
"""
Model Discovery Utilities
Shared functions for discovering and registering new models across different commands.
"""
from typing import List, Optional, Tuple
from pathlib import Path
from rich.console import Console
from kt_kernel.cli.utils.model_scanner import (
discover_models,
scan_directory_for_models,
ScannedModel,
)
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry, UserModel
console = Console()
def discover_and_register_global(
min_size_gb: float = 2.0, max_depth: int = 6, show_progress: bool = True, lang: str = "en"
) -> Tuple[int, int, List[UserModel]]:
"""
Perform global model discovery and register new models.
Args:
min_size_gb: Minimum model size in GB
max_depth: Maximum search depth
show_progress: Whether to show progress messages
lang: Language for messages ("en" or "zh")
Returns:
Tuple of (total_found, new_found, registered_models)
"""
registry = UserModelRegistry()
if show_progress:
if lang == "zh":
console.print("[dim]正在扫描系统中的模型权重,这可能需要30-60秒...[/dim]")
else:
console.print("[dim]Scanning system for model weights, this may take 30-60 seconds...[/dim]")
# Global scan
all_models = discover_models(mount_points=None, min_size_gb=min_size_gb, max_depth=max_depth)
# Filter out existing models
new_models = []
for model in all_models:
if not registry.find_by_path(model.path):
new_models.append(model)
# Register new models
registered = []
for model in new_models:
user_model = _create_and_register_model(registry, model)
if user_model:
registered.append(user_model)
return len(all_models), len(new_models), registered
def discover_and_register_path(
path: str,
min_size_gb: float = 2.0,
existing_paths: Optional[set] = None,
show_progress: bool = True,
lang: str = "en",
) -> Tuple[int, int, List[UserModel]]:
"""
Discover models in a specific path and register new ones.
Args:
path: Directory path to scan
min_size_gb: Minimum model file size in GB
existing_paths: Set of already discovered paths in this session (optional)
show_progress: Whether to show progress messages
lang: Language for messages ("en" or "zh")
Returns:
Tuple of (total_found, new_found, registered_models)
"""
registry = UserModelRegistry()
if show_progress:
if lang == "zh":
console.print(f"[dim]正在扫描 {path}...[/dim]")
else:
console.print(f"[dim]Scanning {path}...[/dim]")
# Scan directory
model_info = scan_directory_for_models(path, min_file_size_gb=min_size_gb)
if not model_info:
return 0, 0, []
# Convert to ScannedModel and filter
new_models = []
for dir_path, (format_type, size_bytes, file_count, files) in model_info.items():
# Check if already in registry
if registry.find_by_path(dir_path):
continue
# Check if already discovered in this session
if existing_paths and dir_path in existing_paths:
continue
model = ScannedModel(
path=dir_path, format=format_type, size_bytes=size_bytes, file_count=file_count, files=files
)
new_models.append(model)
# Register new models
registered = []
for model in new_models:
user_model = _create_and_register_model(registry, model)
if user_model:
registered.append(user_model)
return len(model_info), len(new_models), registered
def _create_and_register_model(registry: UserModelRegistry, scanned_model: ScannedModel) -> Optional[UserModel]:
"""
Create a UserModel from ScannedModel and register it.
Handles name conflicts by suggesting a unique name (e.g., model-2, model-3).
Automatically detects repo_id from README.md YAML frontmatter.
Automatically detects and caches MoE information for safetensors models.
Args:
registry: UserModelRegistry instance
scanned_model: ScannedModel to register
Returns:
Registered UserModel or None if failed
"""
# Use suggest_name to get a unique name (adds -2, -3, etc. if needed)
unique_name = registry.suggest_name(scanned_model.folder_name)
user_model = UserModel(name=unique_name, path=scanned_model.path, format=scanned_model.format)
# Auto-detect repo_id from README.md (only YAML frontmatter)
try:
from kt_kernel.cli.utils.repo_detector import detect_repo_for_model
repo_info = detect_repo_for_model(scanned_model.path)
if repo_info:
repo_id, repo_type = repo_info
user_model.repo_id = repo_id
user_model.repo_type = repo_type
except Exception:
# Silently continue if detection fails
pass
# Auto-detect MoE information for safetensors models
if scanned_model.format == "safetensors":
try:
from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model
moe_result = analyze_moe_model(scanned_model.path, use_cache=True)
if moe_result and moe_result.get("is_moe"):
user_model.is_moe = True
user_model.moe_num_experts = moe_result.get("num_experts")
user_model.moe_num_experts_per_tok = moe_result.get("num_experts_per_tok")
else:
user_model.is_moe = False
except Exception:
# Silently continue if MoE detection fails
# is_moe will remain None
pass
try:
registry.add_model(user_model)
return user_model
except Exception:
# Should not happen since we used suggest_name, but handle gracefully
return None
def format_discovery_summary(
total_found: int,
new_found: int,
registered: List[UserModel],
lang: str = "en",
show_models: bool = True,
max_show: int = 10,
) -> None:
"""
Print formatted discovery summary.
Args:
total_found: Total models found
new_found: New models found
registered: List of registered UserModel objects
lang: Language ("en" or "zh")
show_models: Whether to show model list
max_show: Maximum models to show
"""
console.print()
if new_found == 0:
if total_found > 0:
if lang == "zh":
console.print(f"[green]✓[/green] 扫描完成:找到 {total_found} 个模型,所有模型均已在列表中")
else:
console.print(f"[green]✓[/green] Scan complete: found {total_found} models, all already in the list")
else:
if lang == "zh":
console.print("[yellow]未找到模型[/yellow]")
else:
console.print("[yellow]No models found[/yellow]")
return
# Show summary
if lang == "zh":
console.print(f"[green]✓[/green] 扫描完成:找到 {total_found} 个模型,其中 {new_found} 个为新模型")
else:
console.print(f"[green]✓[/green] Scan complete: found {total_found} models, {new_found} are new")
# Show registered count
if len(registered) > 0:
if lang == "zh":
console.print(f"[green]✓[/green] 成功添加 {len(registered)} 个新模型到列表")
else:
console.print(f"[green]✓[/green] Successfully added {len(registered)} new models to the list")
# Show model list
if show_models and registered:
console.print()
if lang == "zh":
console.print(f"[dim]新发现的模型(前{max_show}个):[/dim]")
else:
console.print(f"[dim]Newly discovered models (first {max_show}):[/dim]")
for i, model in enumerate(registered[:max_show], 1):
# Size is not stored on UserModel; for now just show name, format, and path
console.print(f" {i}. {model.name} ({model.format})")
console.print(f" [dim]{model.path}[/dim]")
if len(registered) > max_show:
remaining = len(registered) - max_show
if lang == "zh":
console.print(f" [dim]... 还有 {remaining} 个新模型[/dim]")
else:
console.print(f" [dim]... and {remaining} more new models[/dim]")
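The name-conflict handling that `_create_and_register_model` delegates to `registry.suggest_name` can be sketched standalone. This stand-in mirrors the documented "-2, -3, ..." suffixing and is an assumption, not the registry's actual implementation:

```python
def suggest_name(base: str, taken: set) -> str:
    # Hypothetical stand-in for UserModelRegistry.suggest_name:
    # append -2, -3, ... until the name is unique.
    if base not in taken:
        return base
    n = 2
    while f"{base}-{n}" in taken:
        n += 1
    return f"{base}-{n}"

taken = {"deepseek-v3", "deepseek-v3-2"}
assert suggest_name("deepseek-v3", taken) == "deepseek-v3-3"
assert suggest_name("qwen3-moe", taken) == "qwen3-moe"
```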


@@ -0,0 +1,790 @@
"""
Model Scanner
Scans directories for model files (safetensors, gguf) and identifies models
"""
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Set, Tuple, Dict
from collections import defaultdict
import os
import subprocess
import json
@dataclass
class ScannedModel:
"""Temporary structure for scanned model information"""
path: str # Absolute path to model directory
format: str # "safetensors" | "gguf" | "mixed"
size_bytes: int # Total size in bytes
file_count: int # Number of model files
files: List[str] # List of model file names
@property
def size_gb(self) -> float:
"""Get size in GB"""
return self.size_bytes / (1024**3)
@property
def folder_name(self) -> str:
"""Get the folder name (default model name)"""
return Path(self.path).name
class ModelScanner:
"""Scanner for discovering models in directory trees"""
def __init__(self, min_size_gb: float = 10.0):
"""
Initialize scanner
Args:
min_size_gb: Minimum folder size in GB to be considered a model
"""
self.min_size_bytes = int(min_size_gb * 1024**3)
def scan_directory(
self, base_path: Path, exclude_paths: Optional[Set[str]] = None
) -> Tuple[List[ScannedModel], List[str]]:
"""
Scan directory tree for models
Args:
base_path: Root directory to scan
exclude_paths: Set of absolute paths to exclude from results
Returns:
Tuple of (valid_models, warnings)
- valid_models: List of ScannedModel instances
- warnings: List of warning messages
"""
if not base_path.exists():
raise ValueError(f"Path does not exist: {base_path}")
if not base_path.is_dir():
raise ValueError(f"Path is not a directory: {base_path}")
exclude_paths = exclude_paths or set()
results: List[ScannedModel] = []
warnings: List[str] = []
# Walk the directory tree
for root, dirs, files in os.walk(base_path):
root_path = Path(root).resolve()
# Skip if already registered
if str(root_path) in exclude_paths:
dirs[:] = [] # Don't descend into this directory
continue
# Check for model files
safetensors_files = [f for f in files if f.endswith(".safetensors")]
gguf_files = [f for f in files if f.endswith(".gguf")]
if not safetensors_files and not gguf_files:
continue # No model files in this directory
# Calculate total size
model_files = safetensors_files + gguf_files
total_size = self._calculate_total_size(root_path, model_files)
# Check if size meets minimum threshold
if total_size < self.min_size_bytes:
continue # Too small, but keep scanning subdirectories
# Detect format
if safetensors_files and gguf_files:
# Mixed format - issue warning
warnings.append(
f"Mixed format detected in {root_path}: "
f"{len(safetensors_files)} safetensors + {len(gguf_files)} gguf files. "
"Please separate into different folders and re-scan."
)
dirs[:] = [] # Don't descend into mixed format directories
continue
# Determine format
format_type = "safetensors" if safetensors_files else "gguf"
# Create scanned model
scanned = ScannedModel(
path=str(root_path),
format=format_type,
size_bytes=total_size,
file_count=len(model_files),
files=model_files,
)
results.append(scanned)
# Continue scanning subdirectories - they might also contain models
# Each subdirectory is independently checked against the min_size_gb threshold
return results, warnings
def scan_single_path(self, path: Path) -> Optional[ScannedModel]:
"""
Scan a single path for model files
Args:
path: Path to scan
Returns:
ScannedModel instance or None if not a valid model
"""
if not path.exists() or not path.is_dir():
return None
# Find model files
safetensors_files = list(path.glob("*.safetensors"))
gguf_files = list(path.glob("*.gguf"))
if not safetensors_files and not gguf_files:
return None
# Check for mixed format
if safetensors_files and gguf_files:
raise ValueError(
f"Mixed format detected: {len(safetensors_files)} safetensors + "
f"{len(gguf_files)} gguf files. Please use a single format."
)
# Calculate size
model_files = [f.name for f in safetensors_files + gguf_files]
total_size = self._calculate_total_size(path, model_files)
# Determine format
format_type = "safetensors" if safetensors_files else "gguf"
return ScannedModel(
path=str(path.resolve()),
format=format_type,
size_bytes=total_size,
file_count=len(model_files),
files=model_files,
)
def _calculate_total_size(self, directory: Path, filenames: List[str]) -> int:
"""
Calculate total size of specified files in directory
Args:
directory: Directory containing the files
filenames: List of filenames to sum
Returns:
Total size in bytes
"""
total = 0
for filename in filenames:
file_path = directory / filename
if file_path.exists():
try:
total += file_path.stat().st_size
except OSError:
# File might be inaccessible, skip it
pass
return total
# Convenience functions
def scan_directory(
base_path: Path, min_size_gb: float = 10.0, exclude_paths: Optional[Set[str]] = None
) -> Tuple[List[ScannedModel], List[str]]:
"""
Convenience function to scan a directory
Args:
base_path: Root directory to scan
min_size_gb: Minimum folder size in GB
exclude_paths: Set of paths to exclude
Returns:
Tuple of (models, warnings)
"""
scanner = ModelScanner(min_size_gb=min_size_gb)
return scanner.scan_directory(base_path, exclude_paths)
def scan_single_path(path: Path) -> Optional[ScannedModel]:
"""
Convenience function to scan a single path
Args:
path: Path to scan
Returns:
ScannedModel or None
"""
scanner = ModelScanner()
return scanner.scan_single_path(path)
def format_size(size_bytes: int) -> str:
"""
Format size in bytes to human-readable string
Args:
size_bytes: Size in bytes
Returns:
Formatted string (e.g., "42.3 GB")
"""
for unit in ["B", "KB", "MB", "GB", "TB"]:
if size_bytes < 1024.0:
return f"{size_bytes:.1f} {unit}"
size_bytes /= 1024.0
return f"{size_bytes:.1f} PB"
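A quick check of the unit loop above (the function is reproduced here so the snippet runs standalone):

```python
def format_size(size_bytes: int) -> str:
    # Same loop as above: divide by 1024 until the value fits the unit.
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if size_bytes < 1024.0:
            return f"{size_bytes:.1f} {unit}"
        size_bytes /= 1024.0
    return f"{size_bytes:.1f} PB"

assert format_size(512) == "512.0 B"
assert format_size(1536) == "1.5 KB"
assert format_size(2 * 1024**3) == "2.0 GB"
```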
# ===== Fast Scanning with Find Command and Tree-based Root Detection =====
def find_files_fast(mount_point: str, pattern: str, max_depth: int = 6, timeout: int = 30) -> List[str]:
"""
Use find command to quickly locate files
Args:
mount_point: Starting directory
pattern: File pattern (e.g., "config.json", "*.gguf")
max_depth: Maximum directory depth (default: 6)
timeout: Command timeout in seconds
Returns:
List of absolute file paths
"""
try:
# Use shell=True to redirect stderr to /dev/null, ignoring permission errors
result = subprocess.run(
f'find "{mount_point}" -maxdepth {max_depth} -name "{pattern}" -type f 2>/dev/null',
capture_output=True,
text=True,
timeout=timeout,
shell=True,
)
# Return results even if returncode is non-zero (due to permission errors)
# As long as we got some output
if result.stdout:
return [line.strip() for line in result.stdout.strip().split("\n") if line.strip()]
return []
except (subprocess.TimeoutExpired, FileNotFoundError):
return []
def is_valid_model_directory(directory: Path, min_size_gb: float = 10.0) -> Tuple[bool, Optional[str]]:
"""
Check if a directory is a valid model directory
Args:
directory: Path to check
min_size_gb: Minimum size in GB
Returns:
(is_valid, model_type) where model_type is "safetensors", "gguf", or None
"""
if not directory.exists() or not directory.is_dir():
return False, None
has_config = (directory / "config.json").exists()
safetensors_files = list(directory.glob("*.safetensors"))
gguf_files = list(directory.glob("*.gguf"))
# Determine model type
model_type = None
if safetensors_files:  # config.json presence (has_config) is informational, not required here
model_type = "safetensors"
elif gguf_files:
model_type = "gguf"
else:
return False, None
# Check size - only count model files (fast!)
total_size = 0
if model_type == "safetensors":
for f in safetensors_files:
try:
total_size += f.stat().st_size
except OSError:
pass
else: # gguf
for f in gguf_files:
try:
total_size += f.stat().st_size
except OSError:
pass
size_gb = total_size / (1024**3)
if size_gb < min_size_gb:
return False, None
return True, model_type
def scan_all_models_fast(mount_points: List[str], min_size_gb: float = 10.0, max_depth: int = 6) -> List[str]:
"""
Fast scan for all model paths using find command
Args:
mount_points: List of mount points to scan
min_size_gb: Minimum model size in GB
max_depth: Maximum search depth (default: 6)
Returns:
List of valid model directory paths
"""
model_paths = set()
for mount in mount_points:
if not os.path.exists(mount):
continue
# Find all config.json files
config_files = find_files_fast(mount, "config.json", max_depth=max_depth)
for config_path in config_files:
model_dir = Path(config_path).parent
is_valid, model_type = is_valid_model_directory(model_dir, min_size_gb)
if is_valid:
model_paths.add(str(model_dir.resolve()))
# Find all *.gguf files
gguf_files = find_files_fast(mount, "*.gguf", max_depth=max_depth)
for gguf_path in gguf_files:
model_dir = Path(gguf_path).parent
is_valid, model_type = is_valid_model_directory(model_dir, min_size_gb)
if is_valid:
model_paths.add(str(model_dir.resolve()))
return sorted(model_paths)
def get_root_subdirs() -> List[str]:
"""
Get subdirectories of / that are worth scanning
Filters out system paths only
Returns:
List of directories to scan
"""
# System paths to exclude
excluded = {
"dev",
"proc",
"sys",
"run",
"boot",
"tmp",
"usr",
"lib",
"lib64",
"bin",
"sbin",
"etc",
"opt",
"var",
"snap",
}
scan_dirs = []
try:
for entry in os.scandir("/"):
if not entry.is_dir():
continue
# Skip excluded paths
if entry.name in excluded:
continue
scan_dirs.append(entry.path)
except PermissionError:
pass
return sorted(scan_dirs)
def scan_directory_for_models(directory: str, min_file_size_gb: float = 2.0) -> Dict[str, tuple]:
"""
Scan a directory for models using find command with size filter
Uses find's -size filter so that only large model files (default: >2 GB) are matched
Args:
directory: Directory to scan
min_file_size_gb: Minimum individual file size in GB (default: 2.0)
Returns:
Dict mapping model_path -> (model_type, size_bytes, file_count, files)
"""
model_info = {}
# Convert GB to find's format (e.g., 2GB = +2G)
if min_file_size_gb >= 1.0:
size_filter = f"+{int(min_file_size_gb)}G"
else:
size_mb = int(min_file_size_gb * 1024)
size_filter = f"+{size_mb}M"
# 1. Find *.gguf files >= 2GB
gguf_cmd = f'find "{directory}" -name "*.gguf" -type f -size {size_filter} -printf "%p\\t%s\\n" 2>/dev/null'
try:
result = subprocess.run(gguf_cmd, shell=True, capture_output=True, text=True, timeout=120)
except subprocess.TimeoutExpired:
return model_info
# Group by directory
gguf_dirs = defaultdict(list)
for line in result.stdout.strip().split("\n"):
if not line:
continue
parts = line.split("\t")
if len(parts) != 2:
continue
file_path, size_str = parts
file_path_obj = Path(file_path)
dir_path = str(file_path_obj.parent)
gguf_dirs[dir_path].append((file_path_obj.name, int(size_str)))
# Add all gguf directories
for dir_path, files in gguf_dirs.items():
total_size = sum(size for _, size in files)
model_info[dir_path] = ("gguf", total_size, len(files), [name for name, _ in files])
# 2. Find *.safetensors files >= 2GB
safetensors_cmd = (
f'find "{directory}" -name "*.safetensors" -type f -size {size_filter} -printf "%p\\t%s\\n" 2>/dev/null'
)
try:
result = subprocess.run(safetensors_cmd, shell=True, capture_output=True, text=True, timeout=120)
except subprocess.TimeoutExpired:
# Return the gguf results collected so far
return model_info
# Group by directory
safetensors_dirs = defaultdict(list)
for line in result.stdout.strip().split("\n"):
if not line:
continue
parts = line.split("\t")
if len(parts) != 2:
continue
file_path, size_str = parts
file_path_obj = Path(file_path)
dir_path = str(file_path_obj.parent)
safetensors_dirs[dir_path].append((file_path_obj.name, int(size_str)))
# 3. Check each safetensors directory for config.json
for dir_path, files in safetensors_dirs.items():
if os.path.exists(os.path.join(dir_path, "config.json")):
total_size = sum(size for _, size in files)
model_info[dir_path] = ("safetensors", total_size, len(files), [name for name, _ in files])
return model_info
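The `%p\t%s` parsing above can be illustrated on canned `find -printf` output (the file paths and sizes below are made up):

```python
from collections import defaultdict
from pathlib import Path

# Canned output in the "path<TAB>size" format produced by -printf "%p\t%s\n"
stdout = (
    "/data/gguf/model-q4.gguf\t4294967296\n"
    "/data/gguf/model-q8.gguf\t8589934592\n"
)

# Group files by their parent directory, as scan_directory_for_models does
dirs = defaultdict(list)
for line in stdout.strip().split("\n"):
    file_path, size_str = line.split("\t")
    dirs[str(Path(file_path).parent)].append((Path(file_path).name, int(size_str)))

total = sum(size for _, size in dirs["/data/gguf"])
assert total == 12884901888  # 4 GiB + 8 GiB
```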
def scan_all_models_with_info(
mount_points: Optional[List[str]] = None, min_size_gb: float = 10.0, max_depth: int = 6
) -> Dict[str, tuple]:
"""
Fast scan with parallel directory scanning
Strategy:
1. Use provided directories or auto-detect root subdirectories
2. Scan each directory in parallel (one thread per directory)
3. Use find -size +2G to find large model files (>=2GB)
Args:
mount_points: Specific directories to scan, or None to auto-detect from / subdirs
min_size_gb: Not used anymore (kept for API compatibility)
max_depth: Not used anymore (kept for API compatibility)
Returns:
Dict mapping model_path -> (model_type, size_bytes, file_count, files)
"""
from concurrent.futures import ThreadPoolExecutor, as_completed
# Get directories to scan
if mount_points is None:
# Get root subdirectories (exclude system paths)
scan_dirs = get_root_subdirs()
else:
scan_dirs = mount_points
if not scan_dirs:
return {}
model_info = {}
# Scan each directory in parallel (max 8 concurrent)
# Use 2GB threshold to find model files
with ThreadPoolExecutor(max_workers=min(len(scan_dirs), 8)) as executor:
futures = {executor.submit(scan_directory_for_models, d, 2.0): d for d in scan_dirs}
for future in as_completed(futures):
try:
dir_results = future.result()
model_info.update(dir_results)
except Exception:
# Skip directories with errors
pass
return model_info
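The fan-out/merge pattern used here is standard `concurrent.futures`; a minimal sketch with a stand-in scanner (the directory names are hypothetical):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def fake_scan(directory: str) -> dict:
    # Stand-in for scan_directory_for_models; returns hypothetical results.
    return {f"{directory}/model": ("gguf", 1, 1, ["model.gguf"])}

dirs = ["/data", "/mnt/nvme"]
merged = {}
# One worker per directory (capped at 8), results merged as they complete
with ThreadPoolExecutor(max_workers=min(len(dirs), 8)) as executor:
    futures = {executor.submit(fake_scan, d): d for d in dirs}
    for future in as_completed(futures):
        merged.update(future.result())

assert set(merged) == {"/data/model", "/mnt/nvme/model"}
```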
def find_model_roots_from_paths(model_paths: List[str]) -> Tuple[List[str], Dict[str, int]]:
"""
Find optimal root paths from model paths using tree-based algorithm
Algorithm:
1. Build path tree with all intermediate paths
2. DFS to calculate f(x) = subtree sum (number of models in subtree)
3. Find roots where f(parent) = f(x) > max(f(children))
Args:
model_paths: List of model directory paths
Returns:
(root_paths, subtree_sizes) where:
- root_paths: List of inferred root directories
- subtree_sizes: Dict mapping each root to number of models
"""
if not model_paths:
return [], {}
# 1. Build path set (including all intermediate paths)
all_paths = set()
model_set = set(model_paths)
for model_path in model_paths:
path = Path(model_path)
for i in range(1, len(path.parts) + 1):
all_paths.add(str(Path(*path.parts[:i])))
# 2. Build parent-child relationships
children_map = defaultdict(list)
for path in all_paths:
path_obj = Path(path)
if len(path_obj.parts) > 1:
parent = str(path_obj.parent)
if parent in all_paths:
children_map[parent].append(path)
# 3. DFS to calculate f(x) and max_child_f(x)
f = {} # path -> subtree sum
max_child_f = {} # path -> max(f(children))
visited = set()
def dfs(path: str) -> int:
if path in visited:
return f[path]
visited.add(path)
# Current node weight (1 if it's a model path, 0 otherwise)
weight = 1 if path in model_set else 0
# Recursively calculate children
children = children_map.get(path, [])
if not children:
# Leaf node
f[path] = weight
max_child_f[path] = 0
return weight
# Calculate f values for all children
children_f_values = [dfs(child) for child in children]
# Calculate f(x) and max_child_f(x)
f[path] = weight + sum(children_f_values)
max_child_f[path] = max(children_f_values) if children_f_values else 0
return f[path]
# Find top-level nodes (no parent in all_paths)
top_nodes = []
for path in all_paths:
parent = str(Path(path).parent)
if parent not in all_paths or parent == path:
top_nodes.append(path)
# Execute DFS from all top nodes
for top in top_nodes:
dfs(top)
# 4. Find root nodes: f(parent) = f(x) >= max(f(children))
# Note: Use >= instead of > to handle the case where a directory contains only one model
candidate_roots = []
for path in all_paths:
# Skip model paths themselves (leaf nodes in model tree)
if path in model_set:
continue
parent = str(Path(path).parent)
# Check condition: f(parent) = f(x) and f(x) >= max(f(children))
if parent in f and f.get(parent, 0) == f.get(path, 0):
if f.get(path, 0) >= max_child_f.get(path, 0) and f.get(path, 0) > 0:
candidate_roots.append(path)
# 5. Remove redundant roots (prefer deeper paths)
# If a root is an ancestor of another root with the same f value, remove it
roots = []
candidate_roots_sorted = sorted(candidate_roots, key=lambda p: -len(Path(p).parts))
for root in candidate_roots_sorted:
# Check if this root is a parent of any already selected root
is_redundant = False
for selected in roots:
if selected.startswith(root + "/"):
# selected is a child of root
# Only keep root if it has more models (shouldn't happen by algorithm)
if f.get(root, 0) == f.get(selected, 0):
is_redundant = True
break
if not is_redundant:
# Also filter out very shallow paths (< 3 levels)
if len(Path(root).parts) >= 3:
roots.append(root)
# Build subtree sizes for roots
subtree_sizes = {root: f.get(root, 0) for root in roots}
return sorted(roots), subtree_sizes
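Steps 1-2 of the algorithm can be seen on a toy input. The paths below are hypothetical, and this computes only the subtree sums f(x), not the full root-selection pass:

```python
from collections import defaultdict
from pathlib import Path

model_paths = [
    "/data/models/llama-a",
    "/data/models/llama-b",
    "/data/models/qwen",
]

# f(x) = number of model paths at or below x; every prefix of a model
# path receives +1, which is equivalent to the DFS subtree sum.
f = defaultdict(int)
for p in model_paths:
    parts = Path(p).parts
    for i in range(1, len(parts) + 1):
        f[str(Path(*parts[:i]))] += 1

# "/data/models" holds all three models, so f equals 3 there and at its
# parent "/data" -- exactly the f(parent) == f(x) root-detection condition.
assert f["/data/models"] == 3 and f["/data"] == 3
assert f["/data/models/qwen"] == 1
```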
@dataclass
class ModelRootInfo:
"""Information about a detected model root path"""
path: str
model_count: int
models: List[ScannedModel]
def discover_models(
mount_points: Optional[List[str]] = None, min_size_gb: float = 10.0, max_depth: int = 6
) -> List[ScannedModel]:
"""
Discover all model directories on the system
Fast scan using find command to locate all models that meet the criteria
Args:
mount_points: List of mount points to scan (None = auto-detect)
min_size_gb: Minimum model size in GB (default: 10.0)
max_depth: Maximum search depth (default: 6)
Returns:
List of ScannedModel sorted by path
"""
# Auto-detect mount points if not provided
if mount_points is None:
mount_points = _get_mount_points()
# Fast scan with cached info (only scan once!)
model_info = scan_all_models_with_info(mount_points, min_size_gb, max_depth)
if not model_info:
return []
# Convert to ScannedModel objects
results = []
for model_path, (model_type, total_size, file_count, files) in model_info.items():
results.append(
ScannedModel(path=model_path, format=model_type, size_bytes=total_size, file_count=file_count, files=files)
)
# Sort by path
results.sort(key=lambda m: m.path)
return results
def _get_mount_points() -> List[str]:
"""
Get all valid mount points from /proc/mounts, filtering out system paths
Returns:
List of mount point paths suitable for model storage
(excludes root "/" to avoid scanning entire filesystem)
"""
mount_points = set()
# System paths to exclude (unlikely to contain model files)
excluded_paths = [
"/snap/",
"/proc/",
"/sys/",
"/run/",
"/boot",
"/dev/",
"/usr",
"/lib",
"/lib64",
"/bin",
"/sbin",
"/etc",
"/opt",
"/var",
"/tmp",
]
try:
with open("/proc/mounts", "r") as f:
for line in f:
parts = line.split()
if len(parts) < 3:
continue
device, mount_point, fs_type = parts[0], parts[1], parts[2]
# Filter out pseudo filesystems
pseudo_fs = {
"proc",
"sysfs",
"devpts",
"tmpfs",
"devtmpfs",
"cgroup",
"cgroup2",
"pstore",
"bpf",
"tracefs",
"debugfs",
"hugetlbfs",
"mqueue",
"configfs",
"securityfs",
"fuse.gvfsd-fuse",
"fusectl",
"squashfs",
"overlay", # snap packages
}
if fs_type in pseudo_fs:
continue
# Skip root directory (too large to scan)
if mount_point == "/":
continue
# Filter out system paths
if any(mount_point.startswith(x) for x in excluded_paths):
continue
# Only include if it exists and is readable
if os.path.exists(mount_point) and os.access(mount_point, os.R_OK):
mount_points.add(mount_point)
# If no mount points found, add common data directories
if not mount_points:
# Add /home if it exists and is not already a separate mount point
common_paths = ["/home", "/data", "/mnt"]
for path in common_paths:
if os.path.exists(path) and os.access(path, os.R_OK):
mount_points.add(path)
except (FileNotFoundError, PermissionError):
# Fallback to common paths
mount_points = {"/home", "/mnt", "/data"}
return sorted(mount_points)


@@ -0,0 +1,254 @@
"""
Shared model table builders for consistent UI across commands.
Provides reusable table construction functions for displaying models
in kt model list, kt quant, kt run, etc.
"""
from typing import List, Optional, Tuple
from pathlib import Path
from rich.table import Table
from rich.console import Console
import json
def format_model_size(model_path: Path, format_type: str) -> str:
"""Calculate and format model size."""
from kt_kernel.cli.utils.model_scanner import format_size
try:
if format_type == "safetensors":
files = list(model_path.glob("*.safetensors"))
elif format_type == "gguf":
files = list(model_path.glob("*.gguf"))
else:
return "[dim]-[/dim]"
total_size = sum(f.stat().st_size for f in files if f.exists())
return format_size(total_size)
except Exception:
return "[dim]-[/dim]"
def format_repo_info(model) -> str:
"""Format repository information."""
if model.repo_id:
repo_abbr = "hf" if model.repo_type == "huggingface" else "ms"
return f"{repo_abbr}:{model.repo_id}"
return "[dim]-[/dim]"
def format_sha256_status(model, status_map: dict) -> str:
"""Format SHA256 verification status."""
return status_map.get(model.sha256_status or "not_checked", "[dim]?[/dim]")
def build_moe_gpu_table(
models: List, status_map: dict, show_index: bool = True, start_index: int = 1
) -> Tuple[Table, List]:
"""
Build MoE GPU models table.
Args:
models: List of MoE GPU model objects
status_map: SHA256_STATUS_MAP for formatting status
show_index: Whether to show # column for selection (default: True)
start_index: Starting index number
Returns:
Tuple of (Table object, list of models in display order)
"""
table = Table(show_header=True, header_style="bold", show_lines=False)
if show_index:
table.add_column("#", justify="right", style="cyan", no_wrap=True)
table.add_column("Name", style="cyan", no_wrap=True)
table.add_column("Path", style="dim", overflow="fold")
table.add_column("Total", justify="right")
table.add_column("Exps", justify="center", style="yellow")
table.add_column("Act", justify="center", style="green")
table.add_column("Repository", style="dim", overflow="fold")
table.add_column("SHA256", justify="center")
displayed_models = []
for i, model in enumerate(models, start_index):
displayed_models.append(model)
# Calculate size
size_str = format_model_size(Path(model.path), "safetensors")
# MoE info
num_experts = str(model.moe_num_experts) if model.moe_num_experts else "[dim]-[/dim]"
num_active = str(model.moe_num_experts_per_tok) if model.moe_num_experts_per_tok else "[dim]-[/dim]"
# Repository and SHA256
repo_str = format_repo_info(model)
sha256_str = format_sha256_status(model, status_map)
row = []
if show_index:
row.append(str(i))
row.extend([model.name, model.path, size_str, num_experts, num_active, repo_str, sha256_str])
table.add_row(*row)
return table, displayed_models
def build_amx_table(
models: List,
status_map: dict = None, # Kept for API compatibility but not used
show_index: bool = True,
start_index: int = 1,
show_linked_gpus: bool = False,
gpu_models: Optional[List] = None,
) -> Tuple[Table, List]:
"""
Build AMX models table.
Note: AMX models are locally quantized, so no SHA256 verification column.
Args:
models: List of AMX model objects
status_map: (Unused - kept for API compatibility)
show_index: Whether to show # column for selection (default: True)
start_index: Starting index number
show_linked_gpus: Whether to show sub-rows for linked GPU models
gpu_models: List of GPU models (required if show_linked_gpus=True)
Returns:
Tuple of (Table object, list of models in display order)
"""
table = Table(show_header=True, header_style="bold", show_lines=False)
if show_index:
table.add_column("#", justify="right", style="cyan", no_wrap=True)
table.add_column("Name", style="cyan", no_wrap=True)
table.add_column("Path", style="dim", overflow="fold")
table.add_column("Total", justify="right")
table.add_column("Method", justify="center", style="yellow")
table.add_column("NUMA", justify="center", style="green")
table.add_column("Source", style="dim", overflow="fold")
# Build reverse map if needed
amx_used_by_gpu = {}
if show_linked_gpus and gpu_models:
for model in models:
if model.gpu_model_ids:
gpu_names = []
for gpu_id in model.gpu_model_ids:
for gpu_model in gpu_models:
if gpu_model.id == gpu_id:
gpu_names.append(gpu_model.name)
break
if gpu_names:
amx_used_by_gpu[model.id] = gpu_names
displayed_models = []
for i, model in enumerate(models, start_index):
displayed_models.append(model)
# Calculate size
size_str = format_model_size(Path(model.path), "safetensors")
# Read metadata from config.json or UserModel fields
method_from_config = None
numa_from_config = None
try:
config_path = Path(model.path) / "config.json"
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
amx_quant = config.get("amx_quantization", {})
if amx_quant.get("converted"):
method_from_config = amx_quant.get("method")
numa_from_config = amx_quant.get("numa_count")
except Exception:
pass
# Priority: UserModel fields > config.json > "?" placeholder
method_display = (
model.amx_quant_method.upper()
if model.amx_quant_method
else method_from_config.upper() if method_from_config else "[dim]?[/dim]"
)
numa_display = (
str(model.amx_numa_nodes)
if model.amx_numa_nodes
else str(numa_from_config) if numa_from_config else "[dim]?[/dim]"
)
source_display = model.amx_source_model or "[dim]-[/dim]"
row = []
if show_index:
row.append(str(i))
row.extend([model.name, model.path, size_str, method_display, numa_display, source_display])
table.add_row(*row)
# Add sub-row showing linked GPUs
if show_linked_gpus and model.id in amx_used_by_gpu:
gpu_list = amx_used_by_gpu[model.id]
gpu_names_str = ", ".join([f"[dim]{name}[/dim]" for name in gpu_list])
sub_row = []
if show_index:
sub_row.append("")
sub_row.extend([f" [dim]↳ GPU: {gpu_names_str}[/dim]", "", "", "", "", ""])
table.add_row(*sub_row, style="dim")
return table, displayed_models
def build_gguf_table(
models: List, status_map: dict, show_index: bool = True, start_index: int = 1
) -> Tuple[Table, List]:
"""
Build GGUF models table.
Args:
models: List of GGUF model objects
status_map: SHA256_STATUS_MAP for formatting status
show_index: Whether to show # column for selection (default: True)
start_index: Starting index number
Returns:
Tuple of (Table object, list of models in display order)
"""
table = Table(show_header=True, header_style="bold", show_lines=False)
if show_index:
table.add_column("#", justify="right", style="cyan", no_wrap=True)
table.add_column("Name", style="cyan", no_wrap=True)
table.add_column("Path", style="dim", overflow="fold")
table.add_column("Total", justify="right")
table.add_column("Repository", style="dim", overflow="fold")
table.add_column("SHA256", justify="center")
displayed_models = []
for i, model in enumerate(models, start_index):
displayed_models.append(model)
# Calculate size
size_str = format_model_size(Path(model.path), "gguf")
# Repository and SHA256
repo_str = format_repo_info(model)
sha256_str = format_sha256_status(model, status_map)
row = []
if show_index:
row.append(str(i))
row.extend([model.name, model.path, size_str, repo_str, sha256_str])
table.add_row(*row)
return table, displayed_models
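All three table builders repeat the same optional-index row assembly. Distilled into a tiny standalone helper (hypothetical name `build_row`, not part of the module):

```python
from typing import List, Optional

def build_row(values: List[str], index: Optional[int] = None) -> List[str]:
    """Prepend the selection index only when the table shows a '#' column."""
    row = [] if index is None else [str(index)]
    row.extend(values)
    return row

# With an index column (show_index=True) and without
assert build_row(["DeepSeek-V3", "/models/ds"], index=1) == ["1", "DeepSeek-V3", "/models/ds"]
assert build_row(["DeepSeek-V3", "/models/ds"]) == ["DeepSeek-V3", "/models/ds"]
```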


@@ -0,0 +1,918 @@
"""
Model Verifier
SHA256 verification for model integrity
"""
import hashlib
import requests
import os
from pathlib import Path
from typing import Dict, Any, Literal, Tuple
from concurrent.futures import ProcessPoolExecutor, as_completed
def _compute_file_sha256(file_path: Path) -> Tuple[str, str, float]:
"""
Compute SHA256 for a single file (worker function for multiprocessing).
Args:
file_path: Path to the file
Returns:
Tuple of (filename, sha256_hash, file_size_mb)
"""
sha256_hash = hashlib.sha256()
file_size_mb = file_path.stat().st_size / (1024 * 1024)
# Read file in chunks to handle large files
with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(8192 * 1024), b""): # 8MB chunks
sha256_hash.update(byte_block)
return file_path.name, sha256_hash.hexdigest(), file_size_mb
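The worker reads in 8 MB chunks so multi-gigabyte weight files never have to fit in memory. A standalone sketch of the same pattern (hypothetical name `sha256_chunked`), with a sanity check that the chunked digest matches the one-shot digest:

```python
import hashlib
import os
import tempfile
from pathlib import Path

def sha256_chunked(path: Path, chunk_size: int = 8 * 1024 * 1024) -> str:
    """Hash a file in fixed-size chunks so memory use stays flat regardless of file size."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()

# Sanity check: the chunked digest equals hashing the whole payload at once
payload = b"x" * (3 * 1024 * 1024)  # 3 MB, spans several 1 MB chunks below
fd, name = tempfile.mkstemp(suffix=".safetensors")
with os.fdopen(fd, "wb") as f:
    f.write(payload)
assert sha256_chunked(Path(name), chunk_size=1024 * 1024) == hashlib.sha256(payload).hexdigest()
os.unlink(name)
```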
def check_huggingface_connectivity(timeout: int = 5) -> Tuple[bool, str]:
"""
Check if HuggingFace is accessible.
Args:
timeout: Connection timeout in seconds
Returns:
Tuple of (is_accessible, message)
"""
test_url = "https://huggingface.co"
try:
response = requests.head(test_url, timeout=timeout, allow_redirects=True)
if response.status_code < 500: # 2xx, 3xx, 4xx all mean the host is reachable
return True, "HuggingFace is accessible"
return False, f"Server error: HTTP {response.status_code}"
except requests.exceptions.Timeout:
return False, f"Connection to {test_url} timed out"
except requests.exceptions.ConnectionError:
return False, f"Cannot connect to {test_url}"
except requests.exceptions.RequestException as e:
return False, f"Connection error: {str(e)}"
def verify_model_integrity(
repo_type: Literal["huggingface", "modelscope"],
repo_id: str,
local_dir: Path,
progress_callback=None,
) -> Dict[str, Any]:
"""
Verify local model integrity against remote repository SHA256 hashes.
Verifies all important files:
- *.safetensors (weights)
- *.json (config files)
- *.py (custom model code)
Args:
repo_type: Type of repository ("huggingface" or "modelscope")
repo_id: Repository ID (e.g., "deepseek-ai/DeepSeek-V3")
local_dir: Local directory containing model files
progress_callback: Optional callback function(message: str) for progress updates
Returns:
Dictionary with verification results:
{
"status": "passed" | "failed" | "error",
"files_checked": int,
"files_passed": int,
"files_failed": [list of filenames],
"error_message": str (optional)
}
"""
def report_progress(msg: str):
"""Helper to report progress"""
if progress_callback:
progress_callback(msg)
try:
# Convert repo_type to platform format
platform = "hf" if repo_type == "huggingface" else "ms"
# 1. Fetch official SHA256 hashes from remote
report_progress("Fetching official SHA256 hashes from remote repository...")
official_hashes = fetch_model_sha256(repo_id, platform)
report_progress(f"✓ Fetched {len(official_hashes)} file hashes from remote")
if not official_hashes:
return {
"status": "error",
"files_checked": 0,
"files_passed": 0,
"files_failed": [],
"error_message": f"No verifiable files found in remote repository: {repo_id}",
}
# 2. Calculate local SHA256 hashes with progress
report_progress(f"Calculating SHA256 for local files...")
# Get all local files matching the patterns
local_files = []
for pattern in ["*.safetensors", "*.json", "*.py"]:
local_files.extend([f for f in local_dir.glob(pattern) if f.is_file()])
if not local_files:
return {
"status": "error",
"files_checked": 0,
"files_passed": 0,
"files_failed": [],
"error_message": f"No verifiable files found in local directory: {local_dir}",
}
# Calculate hashes for all files
local_hashes = calculate_local_sha256(
local_dir,
file_pattern="*.safetensors", # Unused when files_list is provided
progress_callback=report_progress,
files_list=local_files,
)
report_progress(f"✓ Calculated {len(local_hashes)} local file hashes")
# 3. Compare hashes with progress
report_progress(f"Comparing {len(official_hashes)} files...")
files_failed = []
files_missing = []
files_passed = 0
for idx, (filename, official_hash) in enumerate(official_hashes.items(), 1):
# Handle potential path separators in filename
file_basename = Path(filename).name
# Try to find the file in local hashes
local_hash = None
for local_file, local_hash_value in local_hashes.items():
if Path(local_file).name == file_basename:
local_hash = local_hash_value
break
if local_hash is None:
files_missing.append(filename)
report_progress(f" [{idx}/{len(official_hashes)}] ✗ {file_basename} - MISSING")
elif local_hash.lower() != official_hash.lower():
files_failed.append(f"{filename} (hash mismatch)")
report_progress(f" [{idx}/{len(official_hashes)}] ✗ {file_basename} - HASH MISMATCH")
else:
files_passed += 1
report_progress(f" [{idx}/{len(official_hashes)}] ✓ {file_basename}")
# 4. Return results
total_checked = len(official_hashes)
if files_failed or files_missing:
all_failed = files_failed + [f"{f} (missing)" for f in files_missing]
return {
"status": "failed",
"files_checked": total_checked,
"files_passed": files_passed,
"files_failed": all_failed,
"error_message": f"{len(all_failed)} file(s) failed verification",
}
else:
return {
"status": "passed",
"files_checked": total_checked,
"files_passed": files_passed,
"files_failed": [],
}
except ImportError as e:
return {
"status": "error",
"files_checked": 0,
"files_passed": 0,
"files_failed": [],
"error_message": f"Missing required package: {str(e)}. Install with: pip install huggingface-hub modelscope",
"is_network_error": False,
}
except (
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.RequestException,
) as e:
# Network-related errors - suggest mirror
error_msg = f"Network error: {str(e)}"
if repo_type == "huggingface":
error_msg += "\n\nTry using HuggingFace mirror:\n export HF_ENDPOINT=https://hf-mirror.com"
return {
"status": "error",
"files_checked": 0,
"files_passed": 0,
"files_failed": [],
"error_message": error_msg,
"is_network_error": True,
}
except Exception as e:
return {
"status": "error",
"files_checked": 0,
"files_passed": 0,
"files_failed": [],
"error_message": f"Verification failed: {str(e)}",
"is_network_error": False,
}
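The comparison step matches each remote entry to a local hash by basename and buckets the file as passed, mismatched, or missing. Distilled into a standalone helper (hypothetical name `compare_hashes`), using a dict lookup in place of the inner linear scan:

```python
from pathlib import Path
from typing import Dict, List, Tuple

def compare_hashes(
    official: Dict[str, str], local: Dict[str, str]
) -> Tuple[int, List[str], List[str]]:
    """Bucket remote entries as passed / mismatched / missing, matching by basename."""
    # Index local hashes by basename once, since remote keys may carry path prefixes
    local_by_basename = {Path(name).name: h for name, h in local.items()}
    passed, mismatched, missing = 0, [], []
    for filename, official_hash in official.items():
        local_hash = local_by_basename.get(Path(filename).name)
        if local_hash is None:
            missing.append(filename)
        elif local_hash.lower() != official_hash.lower():
            mismatched.append(filename)
        else:
            passed += 1
    return passed, mismatched, missing
```

Building the basename index once makes the comparison linear instead of quadratic in the number of files.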
def calculate_local_sha256(
local_dir: Path, file_pattern: str = "*.safetensors", progress_callback=None, files_list: list[Path] | None = None
) -> Dict[str, str]:
"""
Calculate SHA256 hashes for files in a directory using parallel processing.
Args:
local_dir: Directory to scan
file_pattern: Glob pattern for files to hash (ignored if files_list is provided)
progress_callback: Optional callback function(message: str) for progress updates
files_list: Optional pre-filtered list of files to hash (overrides file_pattern)
Returns:
Dictionary mapping filename to SHA256 hash
"""
result = {}
if not local_dir.exists():
return result
# Get all files first to report total
if files_list is not None:
files_to_hash = files_list
else:
files_to_hash = [f for f in local_dir.glob(file_pattern) if f.is_file()]
total_files = len(files_to_hash)
if total_files == 0:
return result
# Use min(16, total_files) workers to avoid over-spawning processes
max_workers = min(16, total_files)
if progress_callback:
progress_callback(f" Using {max_workers} parallel workers for SHA256 calculation")
# Use ProcessPoolExecutor for CPU-intensive SHA256 computation
completed_count = 0
with ProcessPoolExecutor(max_workers=max_workers) as executor:
# Submit all tasks
future_to_file = {executor.submit(_compute_file_sha256, file_path): file_path for file_path in files_to_hash}
# Process results as they complete
for future in as_completed(future_to_file):
completed_count += 1
try:
filename, sha256_hash, file_size_mb = future.result()
result[filename] = sha256_hash
if progress_callback:
progress_callback(f" [{completed_count}/{total_files}] ✓ {filename} ({file_size_mb:.1f} MB)")
except Exception as e:
file_path = future_to_file[future]
if progress_callback:
progress_callback(f" [{completed_count}/{total_files}] ✗ {file_path.name} - Error: {str(e)}")
return result
def fetch_model_sha256(
repo_id: str,
platform: Literal["hf", "ms"],
revision: str | None = None,
use_mirror: bool = False,
timeout: int | None = None,
) -> dict[str, str]:
"""
Fetch the SHA256 hashes of all important files in a model repository.
Includes:
- *.safetensors (weight files)
- *.json (config files: config.json, tokenizer_config.json, etc.)
- *.py (custom model code: modeling.py, configuration.py, etc.)
Args:
repo_id: Repository ID, e.g., "Qwen/Qwen3-30B-A3B"
platform: Platform, "hf" (HuggingFace) or "ms" (ModelScope)
revision: Revision/branch; defaults to "main" for HuggingFace and "master" for ModelScope
use_mirror: Whether to use a mirror (HuggingFace only)
timeout: Network request timeout; None means no timeout
Returns:
dict: Mapping from filename to SHA256, e.g., {"model-00001-of-00016.safetensors": "abc123...", "config.json": "def456..."}
"""
if platform == "hf":
# Try a direct connection first; fall back to the mirror on failure
try:
return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=use_mirror, timeout=timeout)
except Exception:
# If not already in mirror mode, retry automatically via the mirror
if not use_mirror:
return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=True, timeout=timeout)
raise
elif platform == "ms":
return _fetch_from_modelscope(repo_id, revision or "master", timeout=timeout)
else:
raise ValueError(f"Unsupported platform: {platform}; use 'hf' or 'ms'")
def _fetch_from_huggingface(
repo_id: str, revision: str, use_mirror: bool = False, timeout: int | None = None
) -> dict[str, str]:
"""Fetch the SHA256 of all important files from HuggingFace.
Args:
repo_id: Repository ID
revision: Revision/branch
use_mirror: Whether to use the mirror (hf-mirror.com)
timeout: Network request timeout; None means no timeout
"""
import socket
# Point the HuggingFace client at the mirror via the environment variable if requested
original_endpoint = os.environ.get("HF_ENDPOINT")
if use_mirror and not original_endpoint:
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# Set socket timeout if specified
original_timeout = socket.getdefaulttimeout()
if timeout is not None:
socket.setdefaulttimeout(timeout)
from huggingface_hub import HfApi, list_repo_files
try:
api = HfApi()
all_files = list_repo_files(repo_id=repo_id, revision=revision)
# Filter important files: *.safetensors, *.json, *.py
important_files = [f for f in all_files if f.endswith((".safetensors", ".json", ".py"))]
if not important_files:
return {}
paths_info = api.get_paths_info(
repo_id=repo_id,
paths=important_files,
revision=revision,
)
result = {}
for file_info in paths_info:
if hasattr(file_info, "lfs") and file_info.lfs is not None:
sha256 = file_info.lfs.sha256
else:
sha256 = getattr(file_info, "blob_id", None)
result[file_info.path] = sha256
return result
finally:
# Restore the original socket timeout
socket.setdefaulttimeout(original_timeout)
# Restore the original environment variable
if use_mirror and not original_endpoint:
os.environ.pop("HF_ENDPOINT", None)
elif original_endpoint:
os.environ["HF_ENDPOINT"] = original_endpoint
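The paired save/restore of `HF_ENDPOINT` and the global socket timeout in the `try`/`finally` above is a reusable pattern. One way to package it, sketched as a hypothetical `temporary_network_env` context manager (not part of the module):

```python
import os
import socket
from contextlib import contextmanager
from typing import Optional

@contextmanager
def temporary_network_env(endpoint: Optional[str], timeout: Optional[float]):
    """Temporarily set an endpoint env var and the global socket timeout; restore both on exit."""
    old_endpoint = os.environ.get("HF_ENDPOINT")
    old_timeout = socket.getdefaulttimeout()
    try:
        if endpoint is not None:
            os.environ["HF_ENDPOINT"] = endpoint
        if timeout is not None:
            socket.setdefaulttimeout(timeout)
        yield
    finally:
        socket.setdefaulttimeout(old_timeout)
        if endpoint is not None:
            if old_endpoint is None:
                os.environ.pop("HF_ENDPOINT", None)
            else:
                os.environ["HF_ENDPOINT"] = old_endpoint
```

The `finally` block guarantees cleanup even when the fetch inside the `with` body raises, which is exactly what the inline version above relies on.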
def _fetch_from_modelscope(repo_id: str, revision: str, timeout: int | None = None) -> dict[str, str]:
"""Fetch the SHA256 of all important files from ModelScope.
Args:
repo_id: Repository ID
revision: Revision/branch
timeout: Network request timeout; None means no timeout
"""
import socket
from modelscope.hub.api import HubApi
# Set socket timeout if specified
original_timeout = socket.getdefaulttimeout()
if timeout is not None:
socket.setdefaulttimeout(timeout)
try:
api = HubApi()
files_info = api.get_model_files(model_id=repo_id, revision=revision)
result = {}
for file_info in files_info:
filename = file_info.get("Name", file_info.get("Path", ""))
# Filter important files: *.safetensors, *.json, *.py
if filename.endswith((".safetensors", ".json", ".py")):
sha256 = file_info.get("Sha256", file_info.get("sha256", None))
result[filename] = sha256
return result
finally:
# Restore the original socket timeout
socket.setdefaulttimeout(original_timeout)
def verify_model_integrity_with_progress(
repo_type: Literal["huggingface", "modelscope"],
repo_id: str,
local_dir: Path,
progress_callback=None,
verbose: bool = False,
use_mirror: bool = False,
files_to_verify: list[str] | None = None,
timeout: int | None = None,
) -> Dict[str, Any]:
"""
Verify model integrity with enhanced progress reporting for Rich Progress bars.
This is a wrapper around verify_model_integrity() that provides more detailed
progress information suitable for progress bar display.
The progress_callback receives:
- (message: str, total: int, current: int) for countable operations
- (message: str) for status updates
Args:
repo_type: Repository type ("huggingface" or "modelscope")
repo_id: Repository ID
local_dir: Local directory path
progress_callback: Optional callback for progress updates
verbose: If True, output detailed SHA256 comparison for each file
use_mirror: If True, use HuggingFace mirror (hf-mirror.com)
files_to_verify: Optional list of specific files to verify (for re-verification)
timeout: Network request timeout in seconds (None = no timeout)
"""
def report_progress(msg: str, total=None, current=None):
"""Enhanced progress reporter"""
if progress_callback:
progress_callback(msg, total, current)
try:
platform = "hf" if repo_type == "huggingface" else "ms"
# 1. Fetch official SHA256 hashes
if files_to_verify:
report_progress(f"Fetching SHA256 hashes for {len(files_to_verify)} files...")
elif use_mirror and platform == "hf":
report_progress("Fetching official SHA256 hashes from mirror (hf-mirror.com)...")
else:
report_progress("Fetching official SHA256 hashes from remote repository...")
official_hashes = fetch_model_sha256(repo_id, platform, use_mirror=use_mirror, timeout=timeout)
# Filter to only requested files if specified
if files_to_verify:
# Extract clean filenames from files_to_verify (remove markers like "(missing)")
clean_filenames = set()
for f in files_to_verify:
clean_f = f.replace(" (missing)", "").replace(" (hash mismatch)", "").strip()
# Ensure we only use the filename, not full path
clean_filenames.add(Path(clean_f).name)
# Filter official_hashes to only include requested files
# Compare using basename since official_hashes keys might have paths
official_hashes = {k: v for k, v in official_hashes.items() if Path(k).name in clean_filenames}
report_progress(f"✓ Fetched {len(official_hashes)} file hashes from remote")
if not official_hashes:
return {
"status": "error",
"files_checked": 0,
"files_passed": 0,
"files_failed": [],
"error_message": f"No verifiable files found in remote repository: {repo_id}",
}
# 2. Calculate local SHA256 hashes
local_dir_path = Path(local_dir)
# Only hash the files we need to verify
if files_to_verify:
# Extract clean filenames (without markers)
clean_filenames = set()
for f in files_to_verify:
clean_f = f.replace(" (missing)", "").replace(" (hash mismatch)", "").strip()
# Ensure we only use the filename, not full path
clean_filenames.add(Path(clean_f).name)
# Only hash files that match the clean filenames (weights, configs, and custom code)
files_to_hash = [
f
for pattern in ("*.safetensors", "*.json", "*.py")
for f in local_dir_path.glob(pattern)
if f.is_file() and f.name in clean_filenames
]
else:
# Hash the same file types the remote hash list covers, as in verify_model_integrity()
files_to_hash = [
f for pattern in ("*.safetensors", "*.json", "*.py") for f in local_dir_path.glob(pattern) if f.is_file()
]
total_files = len(files_to_hash)
if files_to_verify:
report_progress(f"Calculating SHA256 for {total_files} repaired files...", total=total_files, current=0)
else:
report_progress(f"Calculating SHA256 for local files...", total=total_files, current=0)
# Progress wrapper for hashing
completed_count = [0] # Use list for mutable closure
def hash_progress_callback(msg: str):
if "Using" in msg and "workers" in msg:
report_progress(msg)
elif "[" in msg and "/" in msg and "]" in msg:
# Progress update like: [1/10] ✓ filename (123.4 MB)
completed_count[0] += 1
report_progress(msg, total=total_files, current=completed_count[0])
# Always pass the pre-filtered files_to_hash list
local_hashes = calculate_local_sha256(
local_dir_path,
"*.safetensors", # Unused: files_list takes precedence
progress_callback=hash_progress_callback,
files_list=files_to_hash,
)
report_progress(f"✓ Calculated {len(local_hashes)} local file hashes")
# 3. Compare hashes
report_progress(f"Comparing {len(official_hashes)} files...", total=len(official_hashes), current=0)
files_failed = []
files_missing = []
files_passed = 0
for idx, (filename, official_hash) in enumerate(official_hashes.items(), 1):
file_basename = Path(filename).name
# Find matching local file
local_hash = None
for local_file, local_hash_value in local_hashes.items():
if Path(local_file).name == file_basename:
local_hash = local_hash_value
break
if local_hash is None:
files_missing.append(filename)
if verbose:
report_progress(
f"[{idx}/{len(official_hashes)}] ✗ {file_basename} (missing)\n Remote: {official_hash}\n Local: <missing>",
total=len(official_hashes),
current=idx,
)
else:
report_progress(
f"[{idx}/{len(official_hashes)}] ✗ {file_basename} (missing)",
total=len(official_hashes),
current=idx,
)
elif local_hash.lower() != official_hash.lower():
files_failed.append(f"{filename} (hash mismatch)")
if verbose:
report_progress(
f"[{idx}/{len(official_hashes)}] ✗ {file_basename} (hash mismatch)\n Remote: {official_hash}\n Local: {local_hash}",
total=len(official_hashes),
current=idx,
)
else:
report_progress(
f"[{idx}/{len(official_hashes)}] ✗ {file_basename} (hash mismatch)",
total=len(official_hashes),
current=idx,
)
else:
files_passed += 1
if verbose:
report_progress(
f"[{idx}/{len(official_hashes)}] ✓ {file_basename}\n Remote: {official_hash}\n Local: {local_hash}",
total=len(official_hashes),
current=idx,
)
else:
report_progress(
f"[{idx}/{len(official_hashes)}] ✓ {file_basename}", total=len(official_hashes), current=idx
)
# 4. Return results
total_checked = len(official_hashes)
if files_failed or files_missing:
all_failed = files_failed + [f"{f} (missing)" for f in files_missing]
return {
"status": "failed",
"files_checked": total_checked,
"files_passed": files_passed,
"files_failed": all_failed,
"error_message": f"{len(all_failed)} file(s) failed verification",
}
else:
return {
"status": "passed",
"files_checked": total_checked,
"files_passed": files_passed,
"files_failed": [],
}
except (
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.RequestException,
TimeoutError, # Socket timeout from socket.setdefaulttimeout()
OSError, # Network-related OS errors
) as e:
error_msg = f"Network error: {str(e)}"
if repo_type == "huggingface":
error_msg += "\n\nTry using HuggingFace mirror:\n export HF_ENDPOINT=https://hf-mirror.com"
return {
"status": "error",
"files_checked": 0,
"files_passed": 0,
"files_failed": [],
"error_message": error_msg,
"is_network_error": True,
}
except Exception as e:
return {
"status": "error",
"files_checked": 0,
"files_passed": 0,
"files_failed": [],
"error_message": f"Verification failed: {str(e)}",
"is_network_error": False,
}
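Re-verification entries carry status markers such as `" (missing)"` from a previous run; the stripping-and-basename step used above can be expressed as a small standalone helper (hypothetical name `clean_filenames`, not part of the module):

```python
from pathlib import Path
from typing import List, Set

def clean_filenames(entries: List[str]) -> Set[str]:
    """Strip status markers appended by an earlier verification run and reduce to basenames."""
    cleaned = set()
    for entry in entries:
        # Remove the "(missing)" / "(hash mismatch)" suffixes, then drop any path prefix
        name = entry.replace(" (missing)", "").replace(" (hash mismatch)", "").strip()
        cleaned.add(Path(name).name)
    return cleaned

assert clean_filenames(
    ["model-00001.safetensors (missing)", "sub/config.json (hash mismatch)"]
) == {"model-00001.safetensors", "config.json"}
```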
def pre_operation_verification(user_model, user_registry, operation_name: str = "operation") -> None:
"""Pre-operation verification of model integrity.
Can be used before running or quantizing models to ensure integrity.
Args:
user_model: UserModel object to verify
user_registry: UserModelRegistry instance
operation_name: Name of the operation (e.g., "running", "quantizing")
"""
from rich.prompt import Prompt, Confirm
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, MofNCompleteColumn, TimeElapsedColumn
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
from kt_kernel.cli.i18n import get_lang
from kt_kernel.cli.utils.console import console, print_info, print_warning, print_error, print_success, print_step
import typer
lang = get_lang()
# Check if already verified
if user_model.sha256_status == "passed":
console.print()
print_info("Model integrity already verified ✓")
console.print()
return
# Model not verified yet
console.print()
console.print("[bold yellow]═══ Model Integrity Check ═══[/bold yellow]")
console.print()
# Check if repo_id exists
if not user_model.repo_id:
# No repo_id - ask user to provide one
console.print("[yellow]No repository ID configured for this model.[/yellow]")
console.print()
console.print("To verify model integrity, we need the repository ID (e.g., 'deepseek-ai/DeepSeek-V3')")
console.print()
if not Confirm.ask("Would you like to configure repository ID now?", default=True):
console.print()
print_warning(f"Skipping verification. Model will be used for {operation_name} without integrity check.")
console.print()
return
# Ask for repo type
console.print()
console.print("Repository type:")
console.print(" [cyan][1][/cyan] HuggingFace")
console.print(" [cyan][2][/cyan] ModelScope")
console.print()
repo_type_choice = Prompt.ask("Select repository type", choices=["1", "2"], default="1")
repo_type = "huggingface" if repo_type_choice == "1" else "modelscope"
# Ask for repo_id
console.print()
repo_id = Prompt.ask("Enter repository ID (e.g., deepseek-ai/DeepSeek-V3)")
# Update model
user_registry.update_model(user_model.name, {"repo_type": repo_type, "repo_id": repo_id})
user_model.repo_type = repo_type
user_model.repo_id = repo_id
console.print()
print_success(f"Repository configured: {repo_type}:{repo_id}")
console.print()
# Now ask if user wants to verify
console.print("[dim]Model integrity verification is a one-time check that ensures your[/dim]")
console.print("[dim]model weights are not corrupted. This helps prevent runtime errors.[/dim]")
console.print()
if not Confirm.ask(f"Would you like to verify model integrity before {operation_name}?", default=True):
console.print()
print_warning(f"Skipping verification. Model will be used for {operation_name} without integrity check.")
console.print()
return
# Perform verification
console.print()
print_step("Verifying model integrity...")
console.print()
# Check connectivity first
use_mirror = False
if user_model.repo_type == "huggingface":
with console.status("[dim]Checking HuggingFace connectivity...[/dim]"):
is_accessible, message = check_huggingface_connectivity(timeout=5)
if not is_accessible:
print_warning("HuggingFace Connection Failed")
console.print()
console.print(f" {message}")
console.print()
console.print(" [yellow]Auto-switching to HuggingFace mirror:[/yellow] [cyan]hf-mirror.com[/cyan]")
console.print()
use_mirror = True
# Fetch remote hashes with timeout
def fetch_with_timeout(repo_type, repo_id, use_mirror, timeout):
"""Fetch hashes with timeout."""
executor = ThreadPoolExecutor(max_workers=1)
try:
platform = "hf" if repo_type == "huggingface" else "ms"
future = executor.submit(fetch_model_sha256, repo_id, platform, use_mirror=use_mirror, timeout=timeout)
hashes = future.result(timeout=timeout)
executor.shutdown(wait=False)
return (hashes, False)
except Exception: # FutureTimeoutError and fetch errors are both treated as a failed fetch
executor.shutdown(wait=False)
return (None, True)
# Try fetching hashes
status = console.status("[dim]Fetching remote hashes...[/dim]")
status.start()
official_hashes, timed_out = fetch_with_timeout(user_model.repo_type, user_model.repo_id, use_mirror, 10)
status.stop()
# Handle timeout with fallback
if timed_out and user_model.repo_type == "huggingface" and not use_mirror:
print_warning("HuggingFace Fetch Timeout (10s)")
console.print()
console.print(" [yellow]Trying HuggingFace mirror...[/yellow]")
console.print()
status = console.status("[dim]Fetching remote hashes from mirror...[/dim]")
status.start()
official_hashes, timed_out = fetch_with_timeout(user_model.repo_type, user_model.repo_id, True, 10)
status.stop()
if timed_out and user_model.repo_type == "huggingface":
print_warning("HuggingFace Mirror Timeout (10s)")
console.print()
console.print(" [yellow]Fallback to ModelScope...[/yellow]")
console.print()
status = console.status("[dim]Fetching remote hashes from ModelScope...[/dim]")
status.start()
official_hashes, timed_out = fetch_with_timeout("modelscope", user_model.repo_id, False, 10)
status.stop()
if not official_hashes or timed_out:
print_error("Failed to fetch remote hashes (network timeout)")
console.print()
console.print(" [yellow]Unable to verify model integrity due to network issues.[/yellow]")
console.print()
if not Confirm.ask(f"Continue {operation_name} without verification?", default=False):
raise typer.Exit(0)
console.print()
return
console.print(f" [green]✓ Fetched {len(official_hashes)} file hashes[/green]")
console.print()
# Calculate local hashes and compare
local_dir = Path(user_model.path)
# Hash the same file types the remote hash list covers (weights, configs, custom code)
files_to_hash = [f for pattern in ("*.safetensors", "*.json", "*.py") for f in local_dir.glob(pattern) if f.is_file()]
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
MofNCompleteColumn(),
TimeElapsedColumn(),
console=console,
) as progress:
# Calculate local hashes
task = progress.add_task("[yellow]Calculating local SHA256...", total=len(files_to_hash))
def hash_callback(msg):
# Per-file progress lines look like: [1/10] ✓ filename (123.4 MB)
if "[" in msg and "/" in msg and "]" in msg and ("✓" in msg or "✗" in msg):
progress.advance(task)
local_hashes = calculate_local_sha256(local_dir, "*.safetensors", progress_callback=hash_callback, files_list=files_to_hash)
progress.remove_task(task)
console.print(f" [green]✓ Calculated {len(local_hashes)} local hashes[/green]")
console.print()
# Compare hashes
task = progress.add_task("[blue]Comparing hashes...", total=len(official_hashes))
files_failed = []
files_missing = []
files_passed = 0
for filename, official_hash in official_hashes.items():
file_basename = Path(filename).name
local_hash = None
for local_file, local_hash_value in local_hashes.items():
if Path(local_file).name == file_basename:
local_hash = local_hash_value
break
if local_hash is None:
files_missing.append(filename)
elif local_hash.lower() != official_hash.lower():
files_failed.append(f"{filename} (hash mismatch)")
else:
files_passed += 1
progress.advance(task)
progress.remove_task(task)
console.print()
# Check results
if not files_failed and not files_missing:
# Verification passed
user_registry.update_model(user_model.name, {"sha256_status": "passed"})
print_success("Model integrity verification PASSED ✓")
console.print()
console.print(f" All {files_passed} files verified successfully")
console.print()
else:
# Verification failed
user_registry.update_model(user_model.name, {"sha256_status": "failed"})
print_error(f"Model integrity verification FAILED")
console.print()
console.print(f" ✓ Passed: [green]{files_passed}[/green]")
console.print(f" ✗ Failed: [red]{len(files_failed) + len(files_missing)}[/red]")
console.print()
if files_missing:
console.print(f" [red]Missing files ({len(files_missing)}):[/red]")
for f in files_missing[:5]:
console.print(f" - {Path(f).name}")
if len(files_missing) > 5:
console.print(f" ... and {len(files_missing) - 5} more")
console.print()
if files_failed:
console.print(f" [red]Hash mismatch ({len(files_failed)}):[/red]")
for f in files_failed[:5]:
console.print(f" - {f}")
if len(files_failed) > 5:
console.print(f" ... and {len(files_failed) - 5} more")
console.print()
console.print("[bold red]⚠ WARNING: Model weights may be corrupted![/bold red]")
console.print()
console.print("This could cause runtime errors or incorrect inference results.")
console.print()
# Ask if user wants to repair
if Confirm.ask("Would you like to repair (re-download) the corrupted files?", default=True):
console.print()
print_info("Please run: [cyan]kt model verify " + user_model.name + "[/cyan]")
console.print()
console.print("The verify command will guide you through the repair process.")
raise typer.Exit(0)
# Ask if user wants to continue anyway
console.print()
if not Confirm.ask(
f"[yellow]Continue {operation_name} with potentially corrupted weights?[/yellow]", default=False
):
raise typer.Exit(0)
console.print()
print_warning(f"Proceeding with {operation_name} using unverified weights at your own risk...")
console.print()
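`fetch_with_timeout` above bounds a blocking fetch by running it on a worker thread and waiting with `future.result(timeout=...)`. The general pattern, sketched as a hypothetical `run_with_timeout` helper (note that `shutdown(wait=False)` does not cancel a worker that has already started; it is merely abandoned):

```python
from concurrent.futures import ThreadPoolExecutor

def run_with_timeout(fn, timeout, *args, **kwargs):
    """Run fn on a worker thread; return (result, failed) where failed covers timeouts and errors."""
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(fn, *args, **kwargs)
        return future.result(timeout=timeout), False
    except Exception:
        # concurrent.futures.TimeoutError and exceptions raised by fn both land here
        return None, True
    finally:
        # Abandon (but do not kill) a worker that outlives the timeout
        executor.shutdown(wait=False)

# A fast call completes normally; a slow one is reported as failed after the timeout
result, failed = run_with_timeout(lambda: 21 * 2, timeout=5.0)
```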


@@ -0,0 +1,57 @@
"""
Port availability checking utilities.
"""
import socket
from typing import Tuple
def is_port_available(host: str, port: int) -> bool:
"""Check if a port is available on the given host.
Args:
host: Host address (e.g., "0.0.0.0", "127.0.0.1")
port: Port number to check
Returns:
True if port is available, False if occupied
"""
try:
# Probe by attempting to connect: a successful connection means something is listening
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1)
# 0.0.0.0 is a bind address, not a destination, so probe via loopback instead
check_host = "127.0.0.1" if host == "0.0.0.0" else host
result = sock.connect_ex((check_host, port))
sock.close()
# connect_ex returns 0 when the connection succeeded (port occupied);
# a non-zero error code (e.g. connection refused) means the port is free
return result != 0
except Exception:
# If any error occurs, assume port is not available
return False
def find_available_port(host: str, start_port: int, max_attempts: int = 100) -> Tuple[bool, int]:
"""Find an available port starting from start_port.
Args:
host: Host address
start_port: Starting port number to check
max_attempts: Maximum number of ports to try
Returns:
Tuple of (found, port_number)
- found: True if an available port was found
- port_number: The available port number (or start_port if not found)
"""
for port in range(start_port, start_port + max_attempts):
if is_port_available(host, port):
return True, port
return False, start_port
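Because the checker probes by connecting rather than binding, it reports a port as occupied only when something is actually listening there. A minimal, self-contained sketch of the same probe (`port_in_use` is an illustrative name, not this module's API, and inverts the occupied/available convention of `is_port_available`):

```python
import socket

def port_in_use(host: str, port: int) -> bool:
    # connect_ex returns 0 when a listener accepts the connection
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(1)
        return sock.connect_ex((host, port)) == 0

# Occupy an OS-chosen port, then confirm the probe sees it as taken
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("127.0.0.1", 0))  # port 0: let the OS pick a free port
server.listen(1)
taken_port = server.getsockname()[1]
print(port_in_use("127.0.0.1", taken_port))  # True
server.close()
```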

View File

@@ -0,0 +1,347 @@
"""
Interactive configuration for kt quant command.
Provides rich, multi-step interactive configuration for model quantization.
"""
from typing import Optional, Dict, Any
from pathlib import Path
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.prompt import Prompt, Confirm, IntPrompt
from kt_kernel.cli.i18n import t
console = Console()
def select_model_to_quantize() -> Optional[Any]:
"""Select model to quantize interactively."""
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
from kt_kernel.cli.commands.model import is_amx_weights, SHA256_STATUS_MAP
from kt_kernel.cli.utils.model_table_builder import build_moe_gpu_table
registry = UserModelRegistry()
all_models = registry.list_models()
# Filter MoE models only (safetensors, not AMX, is_moe=True)
quant_models = []
for model in all_models:
if model.format == "safetensors":
# Skip AMX models
is_amx, _ = is_amx_weights(model.path)
if is_amx:
continue
# Only include MoE models
if model.is_moe:
quant_models.append(model)
if not quant_models:
console.print(f"[yellow]{t('quant_no_moe_models')}[/yellow]")
console.print()
console.print(f" {t('quant_only_moe')}")
console.print()
console.print(f" {t('quant_add_models', command='kt model scan')}")
console.print(f" {t('quant_add_models', command='kt model add <path>')}")
return None
# Display models
console.print()
console.print(f"[bold green]{t('quant_moe_available')}[/bold green]")
console.print()
# Use shared table builder
table, displayed_models = build_moe_gpu_table(
models=quant_models, status_map=SHA256_STATUS_MAP, show_index=True, start_index=1
)
console.print(table)
console.print()
choice = IntPrompt.ask(t("quant_select_model"), default=1, show_choices=False)
if choice < 1 or choice > len(displayed_models):
console.print(f"[red]{t('quant_invalid_choice')}[/red]")
return None
return displayed_models[choice - 1]
def configure_quantization_method() -> Dict[str, str]:
"""Select quantization method and input type."""
console.print()
console.print(Panel(f"[bold cyan]{t('quant_step2_method')}[/bold cyan]", expand=False))
console.print()
# Method selection
console.print(f"[bold]{t('quant_method_label')}[/bold]")
console.print(f" [cyan][1][/cyan] {t('quant_int4_desc')}")
console.print(f" [cyan][2][/cyan] {t('quant_int8_desc')}")
console.print()
method_choice = Prompt.ask(t("quant_select_method"), choices=["1", "2"], default="1")
method = "int4" if method_choice == "1" else "int8"
console.print()
console.print(f"[bold]{t('quant_input_type_label')}[/bold]")
console.print(f" [cyan][1][/cyan] {t('quant_fp8_desc')}")
console.print(f" [cyan][2][/cyan] {t('quant_fp16_desc')}")
console.print(f" [cyan][3][/cyan] {t('quant_bf16_desc')}")
console.print()
input_choice = Prompt.ask(t("quant_select_input_type"), choices=["1", "2", "3"], default="1")
input_type_map = {"1": "fp8", "2": "fp16", "3": "bf16"}
input_type = input_type_map[input_choice]
return {"method": method, "input_type": input_type}
def configure_cpu_params(max_cores: int, max_numa: int) -> Dict[str, Any]:
"""Configure CPU parameters."""
console.print()
console.print(Panel(f"[bold cyan]{t('quant_step3_cpu')}[/bold cyan]", expand=False))
console.print()
def validate_or_default(value: int, min_val: int, max_val: int, default: int) -> int:
"""Return value if it lies within [min_val, max_val], otherwise the default."""
if min_val <= value <= max_val:
return value
return default
default_threads = max(1, int(max_cores * 0.8))
cpu_threads = IntPrompt.ask(t("quant_cpu_threads_prompt", max=max_cores), default=default_threads)
cpu_threads = validate_or_default(cpu_threads, 1, max_cores, default_threads)
numa_nodes = IntPrompt.ask(t("quant_numa_nodes_prompt", max=max_numa), default=max_numa)
numa_nodes = validate_or_default(numa_nodes, 1, max_numa, max_numa)
# Ask about GPU usage
console.print()
console.print(f"[bold]{t('quant_use_gpu_label')}[/bold]")
console.print(f" [dim]{t('quant_gpu_speedup')}[/dim]")
console.print()
use_gpu = Confirm.ask(t("quant_enable_gpu"), default=True)
return {"cpu_threads": cpu_threads, "numa_nodes": numa_nodes, "use_gpu": use_gpu}
def configure_output_path(model: Any, method: str, numa_nodes: int) -> Path:
"""Configure output path for quantized weights."""
from kt_kernel.cli.config.settings import get_settings
console.print()
console.print(Panel(f"[bold cyan]{t('quant_step4_output')}[/bold cyan]", expand=False))
console.print()
# Generate default output path
model_path = Path(model.path)
method_upper = method.upper()
settings = get_settings()
# Priority: paths.weights > paths.models[0] > model's parent directory
weights_dir = settings.weights_dir
if weights_dir and weights_dir.exists():
# Use configured weights directory (highest priority)
default_output = weights_dir / f"{model_path.name}-AMX{method_upper}-NUMA{numa_nodes}"
else:
# Use first model storage path
model_paths = settings.get_model_paths()
if model_paths and model_paths[0].exists():
default_output = model_paths[0] / f"{model_path.name}-AMX{method_upper}-NUMA{numa_nodes}"
else:
# Fallback to model's parent directory
default_output = model_path.parent / f"{model_path.name}-AMX{method_upper}-NUMA{numa_nodes}"
console.print(f"[dim]{t('quant_default_path')}[/dim] {default_output}")
console.print()
use_default = Confirm.ask(t("quant_use_default"), default=True)
if use_default:
return default_output
custom_path = Prompt.ask(t("quant_custom_path"), default=str(default_output))
return Path(custom_path)
def calculate_quantized_size(source_path: Path, input_type: str, quant_method: str) -> tuple[float, float]:
"""
Calculate source model size and estimated quantized size.
Args:
source_path: Path to source model
input_type: Input type (fp8, fp16, bf16)
quant_method: Quantization method (int4, int8)
Returns:
Tuple of (source_size_gb, estimated_quant_size_gb)
"""
# Calculate source model size
try:
total_bytes = sum(f.stat().st_size for f in source_path.glob("*.safetensors") if f.is_file())
source_size_gb = total_bytes / (1024**3)
except Exception:
return 0.0, 0.0
# Bits mapping
input_bits = {"fp8": 8, "fp16": 16, "bf16": 16}
quant_bits = {"int4": 4, "int8": 8}
input_bit = input_bits.get(input_type, 16)
quant_bit = quant_bits.get(quant_method, 4)
# Estimate: source_size * (quant_bits / input_bits)
ratio = quant_bit / input_bit
estimated_size_gb = source_size_gb * ratio
return source_size_gb, estimated_size_gb
def check_disk_space(output_path: Path, required_size_gb: float) -> tuple[float, bool]:
"""
Check available disk space at output path.
Args:
output_path: Target output path
required_size_gb: Required space in GB
Returns:
Tuple of (available_gb, is_sufficient)
is_sufficient is True if available >= required * 1.2
"""
import shutil
try:
# Get parent directory that exists
check_path = output_path.parent if not output_path.exists() else output_path
while not check_path.exists() and check_path != check_path.parent:
check_path = check_path.parent
stat = shutil.disk_usage(check_path)
available_gb = stat.free / (1024**3)
# Check if available space >= required * 1.2 (20% buffer)
is_sufficient = available_gb >= (required_size_gb * 1.2)
return available_gb, is_sufficient
except Exception:
return 0.0, False
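Since the output directory usually does not exist yet, the check walks up to the nearest existing ancestor before querying the filesystem. A minimal sketch of that walk (`free_gb` is an illustrative name, not this module's API):

```python
import shutil
import tempfile
from pathlib import Path

def free_gb(path: Path) -> float:
    # Walk up until an existing directory is found (the root always exists),
    # then ask the filesystem how much space is free there
    while not path.exists() and path != path.parent:
        path = path.parent
    return shutil.disk_usage(path).free / (1024 ** 3)

with tempfile.TemporaryDirectory() as d:
    target = Path(d) / "not" / "yet" / "created"
    print(free_gb(target) > 0)  # True on any non-full filesystem
```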
def interactive_quant_config() -> Optional[Dict[str, Any]]:
"""
Interactive configuration for kt quant.
Returns configuration dict or None if cancelled.
"""
from kt_kernel.cli.utils.environment import detect_cpu_info
# Get CPU info
cpu_info = detect_cpu_info()
# Step 1: Select model
model = select_model_to_quantize()
if not model:
return None
# Step 1.5: Pre-quantization verification (optional)
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
from kt_kernel.cli.utils.model_verifier import pre_operation_verification
user_registry = UserModelRegistry()
user_model_obj = user_registry.find_by_path(model.path)
if user_model_obj and user_model_obj.format == "safetensors":
pre_operation_verification(user_model_obj, user_registry, operation_name="quantizing")
# Step 2: Configure quantization method
quant_config = configure_quantization_method()
# Step 3: Configure CPU parameters
cpu_config = configure_cpu_params(cpu_info.threads, cpu_info.numa_nodes) # Use logical threads
# Step 4: Configure output path
output_path = configure_output_path(model, quant_config["method"], cpu_config["numa_nodes"])
# Step 4.5: Check if output path already exists and generate unique name
if output_path.exists():
console.print()
console.print(t("quant_output_exists_warn", path=str(output_path)))
console.print()
# Generate unique name by adding suffix
original_name = output_path.name
parent_dir = output_path.parent
counter = 2
while output_path.exists():
new_name = f"{original_name}-{counter}"
output_path = parent_dir / new_name
counter += 1
console.print(t("quant_using_unique_name", path=str(output_path)))
console.print()
# Step 5: Calculate space requirements and check availability
console.print()
console.print(Panel(f"[bold cyan]{t('quant_disk_analysis')}[/bold cyan]", expand=False))
console.print()
source_size_gb, estimated_size_gb = calculate_quantized_size(
Path(model.path), quant_config["input_type"], quant_config["method"]
)
available_gb, is_sufficient = check_disk_space(output_path, estimated_size_gb)
console.print(f" {t('quant_source_size'):<26} [cyan]{source_size_gb:.2f} GB[/cyan]")
console.print(f" {t('quant_estimated_size'):<26} [yellow]{estimated_size_gb:.2f} GB[/yellow]")
space_color = "green" if is_sufficient else "red"
console.print(f" {t('quant_available_space'):<26} [{space_color}]{available_gb:.2f} GB[/{space_color}]")
console.print()
if not is_sufficient:
required_with_buffer = estimated_size_gb * 1.2
console.print(f"[bold red]⚠ {t('quant_insufficient_space')}[/bold red]")
console.print()
console.print(f" {t('quant_required_space'):<26} [yellow]{required_with_buffer:.2f} GB[/yellow]")
console.print(f" {t('quant_available_space'):<26} [red]{available_gb:.2f} GB[/red]")
console.print(f" {t('quant_shortage'):<26} [red]{required_with_buffer - available_gb:.2f} GB[/red]")
console.print()
console.print(f" {t('quant_may_fail')}")
console.print()
if not Confirm.ask(f"[yellow]{t('quant_continue_anyway')}[/yellow]", default=False):
console.print(f"[yellow]{t('quant_cancelled')}[/yellow]")
return None
console.print()
# Summary and confirmation
console.print()
console.print(Panel(f"[bold cyan]{t('quant_config_summary')}[/bold cyan]", expand=False))
console.print()
console.print(f" {t('quant_summary_model'):<15} {model.name}")
console.print(f" {t('quant_summary_method'):<15} {quant_config['method'].upper()}")
console.print(f" {t('quant_summary_input_type'):<15} {quant_config['input_type'].upper()}")
console.print(f" {t('quant_summary_cpu_threads'):<15} {cpu_config['cpu_threads']}")
console.print(f" {t('quant_summary_numa'):<15} {cpu_config['numa_nodes']}")
console.print(f" {t('quant_summary_gpu'):<15} {t('yes') if cpu_config['use_gpu'] else t('no')}")
console.print(f" {t('quant_summary_output'):<15} {output_path}")
console.print()
if not Confirm.ask(f"[bold green]{t('quant_start_question')}[/bold green]", default=True):
console.print(f"[yellow]{t('quant_cancelled')}[/yellow]")
return None
return {
"model": model,
"method": quant_config["method"],
"input_type": quant_config["input_type"],
"cpu_threads": cpu_config["cpu_threads"],
"numa_nodes": cpu_config["numa_nodes"],
"use_gpu": cpu_config["use_gpu"],
"output_path": output_path,
}
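The size estimate in `calculate_quantized_size` is a pure bit-width ratio, and `check_disk_space` then demands a 20% buffer on top of it. A self-contained sketch with hypothetical numbers (`estimate_quant_size_gb` is an illustrative name, not this module's API):

```python
def estimate_quant_size_gb(source_gb: float, input_type: str, method: str) -> float:
    # Output size scales with the ratio of quantized to input bit widths
    input_bits = {"fp8": 8, "fp16": 16, "bf16": 16}
    quant_bits = {"int4": 4, "int8": 8}
    return source_gb * quant_bits[method] / input_bits[input_type]

# Hypothetical 640 GB fp8 checkpoint quantized to int4
estimate = estimate_quant_size_gb(640.0, "fp8", "int4")
print(estimate)                 # 320.0
print(round(estimate * 1.2, 2))  # 384.0 -- required free space with 20% buffer
```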

View File

@@ -0,0 +1,364 @@
"""
Repo Detector
Automatically detect repository information from model README.md files
"""
import re
from pathlib import Path
from typing import Optional, Dict, Tuple
import yaml
def parse_readme_frontmatter(readme_path: Path) -> Optional[Dict]:
"""
Parse YAML frontmatter from README.md
Args:
readme_path: Path to README.md file
Returns:
Dictionary of frontmatter data, or None if not found
"""
if not readme_path.exists():
return None
try:
with open(readme_path, "r", encoding="utf-8") as f:
content = f.read()
# Match YAML frontmatter between --- markers
match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)
if not match:
return None
yaml_content = match.group(1)
# Parse YAML
try:
data = yaml.safe_load(yaml_content)
return data if isinstance(data, dict) else None
except yaml.YAMLError:
return None
except Exception:
return None
def extract_repo_from_frontmatter(frontmatter: Dict) -> Optional[Tuple[str, str]]:
"""
Extract repo_id and repo_type from frontmatter
Args:
frontmatter: Parsed YAML frontmatter dictionary
Returns:
Tuple of (repo_id, repo_type) or None
repo_type is either "huggingface" or "modelscope"
"""
if not frontmatter:
return None
# Priority 1: Extract from license_link (most reliable)
license_link = frontmatter.get("license_link")
if license_link and isinstance(license_link, str):
result = _extract_repo_from_url(license_link)
if result:
return result
# Priority 2: Try to find repo_id from other fields
repo_id = None
# Check base_model field
base_model = frontmatter.get("base_model")
if base_model:
if isinstance(base_model, list) and len(base_model) > 0:
# base_model is a list, take first item
repo_id = base_model[0]
elif isinstance(base_model, str):
repo_id = base_model
# Check model-index field
if not repo_id:
model_index = frontmatter.get("model-index")
if isinstance(model_index, list) and len(model_index) > 0:
first_model = model_index[0]
if isinstance(first_model, dict):
repo_id = first_model.get("name")
# Check model_name field
if not repo_id:
repo_id = frontmatter.get("model_name")
if not repo_id or not isinstance(repo_id, str):
return None
# Validate format: should be "namespace/model-name"
if "/" not in repo_id:
return None
parts = repo_id.split("/")
if len(parts) != 2:
return None
# Determine repo type
repo_type = "huggingface" # Default
# Look for ModelScope indicators
if "modelscope" in repo_id.lower():
repo_type = "modelscope"
# Check tags
tags = frontmatter.get("tags", [])
if isinstance(tags, list):
if "modelscope" in [str(t).lower() for t in tags]:
repo_type = "modelscope"
return (repo_id, repo_type)
def _extract_repo_from_url(url: str) -> Optional[Tuple[str, str]]:
"""
Extract repo_id and repo_type from a URL
Supports:
- https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/LICENSE
- https://modelscope.cn/models/Qwen/Qwen3-30B-A3B
Args:
url: URL string
Returns:
Tuple of (repo_id, repo_type) or None
"""
# HuggingFace pattern: https://huggingface.co/{namespace}/{model}/...
hf_match = re.match(r"https?://huggingface\.co/([^/]+)/([^/]+)", url)
if hf_match:
namespace = hf_match.group(1)
model_name = hf_match.group(2)
repo_id = f"{namespace}/{model_name}"
return (repo_id, "huggingface")
# ModelScope pattern: https://modelscope.cn/models/{namespace}/{model}
ms_match = re.match(r"https?://(?:www\.)?modelscope\.cn/models/([^/]+)/([^/]+)", url)
if ms_match:
namespace = ms_match.group(1)
model_name = ms_match.group(2)
repo_id = f"{namespace}/{model_name}"
return (repo_id, "modelscope")
return None
def extract_repo_from_global_search(readme_path: Path) -> Optional[Tuple[str, str]]:
"""
Extract repo info by globally searching for URLs in README.md
Args:
readme_path: Path to README.md file
Returns:
Tuple of (repo_id, repo_type) or None if not found
"""
if not readme_path.exists():
return None
try:
with open(readme_path, "r", encoding="utf-8") as f:
content = f.read()
# Find all HuggingFace URLs
hf_pattern = r"https?://huggingface\.co/([^/\s]+)/([^/\s\)]+)"
hf_matches = re.findall(hf_pattern, content)
# Find all ModelScope URLs
ms_pattern = r"https?://(?:www\.)?modelscope\.cn/models/([^/\s]+)/([^/\s\)]+)"
ms_matches = re.findall(ms_pattern, content)
# Collect all found repos with their types
found_repos = []
for namespace, model_name in hf_matches:
# Skip common non-repo paths
if namespace.lower() in ["docs", "blog", "spaces", "datasets"]:
continue
if model_name.lower() in ["tree", "blob", "raw", "resolve", "discussions"]:
continue
repo_id = f"{namespace}/{model_name}"
found_repos.append((repo_id, "huggingface"))
for namespace, model_name in ms_matches:
repo_id = f"{namespace}/{model_name}"
found_repos.append((repo_id, "modelscope"))
if not found_repos:
return None
# If multiple repos are found, keep the last occurrence
return found_repos[-1]
except Exception:
return None
def detect_repo_for_model(model_path: str) -> Optional[Tuple[str, str]]:
"""
Detect repository information for a model
Strategy:
Only extract from YAML frontmatter metadata in README.md
(Removed global URL search to avoid false positives)
Args:
model_path: Path to model directory
Returns:
Tuple of (repo_id, repo_type) or None if not detected
"""
model_dir = Path(model_path)
if not model_dir.exists() or not model_dir.is_dir():
return None
# Look for README.md
readme_path = model_dir / "README.md"
if not readme_path.exists():
return None
# Only parse YAML frontmatter (no fallback to global search)
frontmatter = parse_readme_frontmatter(readme_path)
if frontmatter:
return extract_repo_from_frontmatter(frontmatter)
return None
def scan_models_for_repo(model_list) -> Dict:
"""
Scan a list of models and detect repo information
Args:
model_list: List of UserModel objects
Returns:
Dictionary with scan results:
{
'detected': [(model, repo_id, repo_type), ...],
'not_detected': [model, ...],
'skipped': [model, ...] # Already has repo_id
}
"""
results = {"detected": [], "not_detected": [], "skipped": []}
for model in model_list:
# Skip if already has repo_id
if model.repo_id:
results["skipped"].append(model)
continue
# Only process safetensors and gguf models
if model.format not in ["safetensors", "gguf"]:
results["skipped"].append(model)
continue
# Try to detect repo
repo_info = detect_repo_for_model(model.path)
if repo_info:
repo_id, repo_type = repo_info
results["detected"].append((model, repo_id, repo_type))
else:
results["not_detected"].append(model)
return results
def format_detection_report(results: Dict) -> str:
"""
Format scan results into a readable report
Args:
results: Results from scan_models_for_repo()
Returns:
Formatted string report
"""
lines = []
lines.append("=" * 80)
lines.append("Auto-Detection Report")
lines.append("=" * 80)
lines.append("")
# Detected
if results["detected"]:
lines.append(f"✓ Detected repository information ({len(results['detected'])} models):")
lines.append("")
for model, repo_id, repo_type in results["detected"]:
lines.append(f"{model.name}")
lines.append(f" Path: {model.path}")
lines.append(f" Repo: {repo_id} ({repo_type})")
lines.append("")
# Not detected
if results["not_detected"]:
lines.append(f"✗ No repository information found ({len(results['not_detected'])} models):")
lines.append("")
for model in results["not_detected"]:
lines.append(f"{model.name}")
lines.append(f" Path: {model.path}")
lines.append("")
# Skipped
if results["skipped"]:
lines.append(f"⊘ Skipped ({len(results['skipped'])} models):")
lines.append(f" (Already have repo_id or not safetensors/gguf format)")
lines.append("")
lines.append("=" * 80)
lines.append(
f"Summary: {len(results['detected'])} detected, "
f"{len(results['not_detected'])} not detected, "
f"{len(results['skipped'])} skipped"
)
lines.append("=" * 80)
return "\n".join(lines)
def apply_detection_results(results: Dict, registry) -> int:
"""
Apply detected repo information to models in registry
Args:
results: Results from scan_models_for_repo()
registry: UserModelRegistry instance
Returns:
Number of models updated
"""
updated_count = 0
for model, repo_id, repo_type in results["detected"]:
success = registry.update_model(model.name, {"repo_id": repo_id, "repo_type": repo_type})
if success:
updated_count += 1
return updated_count
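The two URL patterns above drive both the frontmatter and global-search paths. A self-contained sketch of the same extraction (the sample URLs come from the docstrings; `repo_from_url` is an illustrative name, not this module's API):

```python
import re

def repo_from_url(url: str):
    # HuggingFace: https://huggingface.co/{namespace}/{model}/...
    m = re.match(r"https?://huggingface\.co/([^/]+)/([^/]+)", url)
    if m:
        return f"{m.group(1)}/{m.group(2)}", "huggingface"
    # ModelScope: https://modelscope.cn/models/{namespace}/{model}
    m = re.match(r"https?://(?:www\.)?modelscope\.cn/models/([^/]+)/([^/]+)", url)
    if m:
        return f"{m.group(1)}/{m.group(2)}", "modelscope"
    return None

print(repo_from_url("https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/LICENSE"))
# ('Qwen/Qwen3-30B-A3B', 'huggingface')
```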

View File

@@ -0,0 +1,111 @@
"""
Configuration save/load for kt run command.
Manages saved run configurations bound to specific models.
"""
from pathlib import Path
from typing import Dict, List, Optional, Any
from datetime import datetime
import yaml
CONFIG_FILE = Path.home() / ".ktransformers" / "run_configs.yaml"
class RunConfigManager:
"""Manager for saved run configurations."""
def __init__(self):
self.config_file = CONFIG_FILE
self._ensure_config_file()
def _ensure_config_file(self):
"""Ensure config file exists."""
if not self.config_file.exists():
self.config_file.parent.mkdir(parents=True, exist_ok=True)
self._save_data({"version": "1.0", "configs": {}})
def _load_data(self) -> Dict:
"""Load raw config data."""
try:
with open(self.config_file, "r", encoding="utf-8") as f:
return yaml.safe_load(f) or {"version": "1.0", "configs": {}}
except Exception:
return {"version": "1.0", "configs": {}}
def _save_data(self, data: Dict):
"""Save raw config data."""
with open(self.config_file, "w", encoding="utf-8") as f:
yaml.dump(data, f, allow_unicode=True, default_flow_style=False)
def list_configs(self, model_id: str) -> List[Dict[str, Any]]:
"""List all saved configs for a model.
Returns:
List of config dicts with 'config_name' and other fields.
"""
data = self._load_data()
configs = data.get("configs", {}).get(model_id, [])
return configs if isinstance(configs, list) else []
def save_config(self, model_id: str, config: Dict[str, Any]):
"""Save a configuration for a model.
Args:
model_id: Model ID to bind config to
config: Configuration dict with all run parameters
"""
data = self._load_data()
if "configs" not in data:
data["configs"] = {}
if model_id not in data["configs"]:
data["configs"][model_id] = []
# Add timestamp
config["created_at"] = datetime.now().isoformat()
# Append config
data["configs"][model_id].append(config)
self._save_data(data)
def delete_config(self, model_id: str, config_index: int) -> bool:
"""Delete a saved configuration.
Args:
model_id: Model ID
config_index: Index of config to delete (0-based)
Returns:
True if deleted, False if not found
"""
data = self._load_data()
if model_id not in data.get("configs", {}):
return False
configs = data["configs"][model_id]
if config_index < 0 or config_index >= len(configs):
return False
configs.pop(config_index)
self._save_data(data)
return True
def get_config(self, model_id: str, config_index: int) -> Optional[Dict[str, Any]]:
"""Get a specific saved configuration.
Args:
model_id: Model ID
config_index: Index of config to get (0-based)
Returns:
Config dict or None if not found
"""
configs = self.list_configs(model_id)
if config_index < 0 or config_index >= len(configs):
return None
return configs[config_index]
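An in-memory sketch of the schema `RunConfigManager` persists to `~/.ktransformers/run_configs.yaml`: a version string plus a per-model list of timestamped config dicts (the model id and config fields below are hypothetical):

```python
from datetime import datetime

store = {"version": "1.0", "configs": {}}

def save_config(store: dict, model_id: str, config: dict) -> None:
    # Each saved config is timestamped and appended to the model's list
    entry = dict(config, created_at=datetime.now().isoformat())
    store["configs"].setdefault(model_id, []).append(entry)

save_config(store, "deepseek-v3", {"config_name": "default", "port": 8000})
print(len(store["configs"]["deepseek-v3"]))  # 1
```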

File diff suppressed because it is too large

View File

@@ -0,0 +1,459 @@
"""
Tuna engine for auto-tuning GPU experts configuration.
Automatically finds the maximum viable num-gpu-experts through binary search
by testing actual server launches with different configurations.
"""
import json
import math
import random
import subprocess
import sys
import time
from pathlib import Path
from typing import Optional
from kt_kernel.cli.utils.console import console, print_error, print_info, print_warning
def get_num_experts(model_path: Path) -> int:
"""
Get the number of experts per layer from model config.
Args:
model_path: Path to the model directory
Returns:
Number of experts per layer
Raises:
ValueError: If config.json not found or num_experts field missing
"""
config_file = model_path / "config.json"
if not config_file.exists():
raise ValueError(f"config.json not found in {model_path}")
try:
config = json.loads(config_file.read_text())
except Exception as e:
raise ValueError(f"Failed to parse config.json: {e}")
# Different models use different field names for the total routed-expert
# count. Note: "num_experts_per_tok" is deliberately excluded -- it is the
# per-token top-k (e.g., 8), not the expert count the search range needs.
possible_keys = [
"n_routed_experts",  # DeepSeek
"num_local_experts",  # Mixtral
"num_experts",  # Qwen / generic
]
for key in possible_keys:
if key in config:
return config[key]
raise ValueError(f"Cannot find num_experts field in {config_file}. " f"Tried: {', '.join(possible_keys)}")
def detect_oom(log_line: Optional[str]) -> bool:
"""
Detect OOM (Out Of Memory) errors from log output.
Args:
log_line: A line from server output
Returns:
True if OOM detected, False otherwise
"""
if log_line is None:
return False
log_lower = log_line.lower()
oom_patterns = [
"cuda out of memory",
"out of memory",
"outofmemoryerror",
"oom",
"failed to allocate",
"cumemalloc failed",
"cumemallocasync failed",
"allocation failed",
]
return any(pattern in log_lower for pattern in oom_patterns)
def test_config(
num_gpu_experts: int,
model_path: Path,
config: dict,
verbose: bool = False,
) -> tuple[bool, float]:
"""
Test if a configuration with given num_gpu_experts works.
Args:
num_gpu_experts: Number of GPU experts to test
model_path: Path to the model
config: Configuration dict with all parameters
verbose: Whether to show detailed logs
Returns:
(success: bool, elapsed_time: float)
- success: True if server starts and inference works
- elapsed_time: Time taken for the test
"""
start_time = time.time()
# Use random port to avoid conflicts
test_port = random.randint(30000, 40000)
# Build command
cmd = [
sys.executable,
"-m",
"sglang.launch_server",
"--model",
str(model_path),
"--port",
str(test_port),
"--host",
"127.0.0.1",
"--tensor-parallel-size",
str(config["tensor_parallel_size"]),
"--kt-num-gpu-experts",
str(num_gpu_experts),
"--max-total-tokens",
str(config["max_total_tokens"]),
]
# Add kt-kernel options
if config.get("weights_path"):
cmd.extend(["--kt-weight-path", str(config["weights_path"])])
else:
cmd.extend(["--kt-weight-path", str(model_path)])
cmd.extend(
[
"--kt-cpuinfer",
str(config.get("cpu_threads", 64)),
"--kt-threadpool-count",
str(config.get("numa_nodes", 2)),
"--kt-method",
config.get("kt_method", "AMXINT4"),
"--kt-gpu-prefill-token-threshold",
str(config.get("kt_gpu_prefill_threshold", 4096)),
]
)
# Add other SGLang options
if config.get("attention_backend"):
cmd.extend(["--attention-backend", config["attention_backend"]])
cmd.extend(
[
"--trust-remote-code",
"--mem-fraction-static",
str(config.get("mem_fraction_static", 0.98)),
"--chunked-prefill-size",
str(config.get("chunked_prefill_size", 4096)),
"--max-running-requests",
str(config.get("max_running_requests", 1)), # Use 1 for faster testing
"--watchdog-timeout",
str(config.get("watchdog_timeout", 3000)),
"--enable-mixed-chunk",
"--enable-p2p-check",
]
)
# Add disable-shared-experts-fusion if specified
if config.get("disable_shared_experts_fusion"):
cmd.append("--disable-shared-experts-fusion")
# Add extra args
if config.get("extra_args"):
cmd.extend(config["extra_args"])
if verbose:
console.print(f"[dim]Command: {' '.join(cmd)}[/dim]")
# Start process
try:
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
env=config.get("env"),
)
except Exception as e:
if verbose:
print_error(f"Failed to start process: {e}")
return False, time.time() - start_time
# Monitor process output
timeout = 60 # Maximum 60 seconds to wait
server_ready = False
try:
while time.time() - start_time < timeout:
# Check if process has output
if process.poll() is not None:
# Process exited
if verbose:
print_warning("Process exited early")
return False, time.time() - start_time
# Read one line of output (readline blocks until a line or EOF)
try:
line = process.stdout.readline()
if not line:
time.sleep(0.1)
continue
if verbose:
console.print(f"[dim]{line.rstrip()}[/dim]")
# Fast OOM detection
if detect_oom(line):
if verbose:
print_warning(f"OOM detected: {line.rstrip()}")
process.terminate()
try:
process.wait(timeout=2)
except subprocess.TimeoutExpired:
process.kill()
return False, time.time() - start_time
# Check for startup success
if "Uvicorn running" in line or "Application startup complete" in line:
server_ready = True
break
except Exception as e:
if verbose:
print_warning(f"Error reading output: {e}")
break
if not server_ready:
# Timeout or failed to start
process.terminate()
try:
process.wait(timeout=2)
except subprocess.TimeoutExpired:
process.kill()
return False, time.time() - start_time
# Server is ready, test inference
success = test_inference(test_port, verbose=verbose)
# Cleanup
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
process.wait(timeout=2)
return success, time.time() - start_time
except KeyboardInterrupt:
# User cancelled
process.terminate()
try:
process.wait(timeout=2)
except subprocess.TimeoutExpired:
process.kill()
raise
except Exception as e:
if verbose:
print_error(f"Test failed with exception: {e}")
try:
process.terminate()
process.wait(timeout=2)
except Exception:
try:
process.kill()
except Exception:
pass
return False, time.time() - start_time
def test_inference(port: int, verbose: bool = False) -> bool:
"""
Test if the server can handle a simple inference request.
Args:
port: Server port
verbose: Whether to show detailed logs
Returns:
True if inference succeeds, False otherwise
"""
try:
# Wait a bit for server to be fully ready
time.sleep(2)
# Try to import OpenAI client
try:
from openai import OpenAI
except ImportError:
if verbose:
print_warning("OpenAI package not available, skipping inference test")
return True # Assume success if we can't test
client = OpenAI(
base_url=f"http://127.0.0.1:{port}/v1",
api_key="test",
)
# Send a simple test request
response = client.chat.completions.create(
model="test",
messages=[{"role": "user", "content": "Hi"}],
max_tokens=1,
temperature=0,
timeout=10,
)
# Check if we got a valid response
success = response.choices and len(response.choices) > 0 and response.choices[0].message.content is not None
if verbose:
if success:
print_info(f"Inference test passed: {response.choices[0].message.content}")
else:
print_warning("Inference test failed: no valid response")
return success
except Exception as e:
if verbose:
print_warning(f"Inference test failed: {e}")
return False
def find_max_gpu_experts(
model_path: Path,
config: dict,
verbose: bool = False,
) -> int:
"""
Binary search to find the maximum viable num_gpu_experts.
Args:
model_path: Path to the model
config: Configuration dict
verbose: Whether to show detailed logs
Returns:
Maximum number of GPU experts that works
"""
# Get number of experts from model config
try:
num_experts = get_num_experts(model_path)
except ValueError as e:
print_error(str(e))
raise
console.print()
console.print(f"Binary search range: [0, {num_experts}]")
console.print()
left, right = 0, num_experts
result = 0
iteration = 0
total_iterations = math.ceil(math.log2(num_experts + 1))
while left <= right:
iteration += 1
mid = (left + right) // 2
console.print(f"[{iteration}/{total_iterations}] Testing gpu-experts={mid}... ", end="")
success, elapsed = test_config(mid, model_path, config, verbose=verbose)
if success:
console.print(f"[green]✓ OK[/green] ({elapsed:.1f}s)")
result = mid
left = mid + 1
else:
console.print(f"[red]✗ FAILED[/red] ({elapsed:.1f}s)")
right = mid - 1
return result

def run_tuna(
    model_path: Path,
    tensor_parallel_size: int,
    max_total_tokens: int,
    kt_method: str,
    verbose: bool = False,
    **kwargs,
) -> int:
    """
    Run tuna auto-tuning to find optimal num_gpu_experts.

    Args:
        model_path: Path to the model
        tensor_parallel_size: Tensor parallel size
        max_total_tokens: Maximum total tokens
        kt_method: KT quantization method
        verbose: Whether to show detailed logs
        **kwargs: Additional configuration parameters

    Returns:
        Optimal num_gpu_experts value

    Raises:
        ValueError: If tuning fails completely
    """
    # Prepare configuration
    config = {
        "tensor_parallel_size": tensor_parallel_size,
        "max_total_tokens": max_total_tokens,
        "kt_method": kt_method,
        **kwargs,
    }

    # Run binary search
    try:
        result = find_max_gpu_experts(model_path, config, verbose=verbose)
    except KeyboardInterrupt:
        console.print()
        print_warning("Tuning cancelled by user")
        raise

    console.print()

    # Check if even 0 doesn't work
    if result == 0:
        console.print("[yellow]Testing if gpu-experts=0 is viable...[/yellow]")
        success, _ = test_config(0, model_path, config, verbose=verbose)
        if not success:
            # Even 0 doesn't work
            console.print()
            print_error("Failed to start server even with all experts on CPU (gpu-experts=0)")
            console.print()
            console.print("[bold]Possible reasons:[/bold]")
            console.print(" • Insufficient GPU memory for base model layers")
            console.print(" • max-total-tokens is too large for available VRAM")
            console.print(" • Tensor parallel configuration issue")
            console.print()
            console.print("[bold]Suggestions:[/bold]")
            console.print(f" • Reduce --max-total-tokens (current: {max_total_tokens})")
            console.print(f" • Reduce --tensor-parallel-size (current: {tensor_parallel_size})")
            console.print(" • Use more GPUs or GPUs with more VRAM")
            console.print(" • Try a smaller model")
            console.print()
            raise ValueError("Minimum GPU memory requirements not met")
        else:
            # 0 works but nothing more
            console.print()
            print_warning("All experts will run on CPU (gpu-experts=0). Performance will be limited by CPU speed.")

    return result
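The core of `find_max_gpu_experts` is a standard "largest value that still passes" binary search over a monotone viability check (if n experts fit in VRAM, so does any smaller n). A minimal standalone sketch with a stubbed check — `find_max_viable` and the 96-expert threshold are illustrative, not part of kt-cli:

```python
import math

def find_max_viable(num_experts: int, is_viable) -> int:
    """Binary search for the largest n in [0, num_experts] with is_viable(n) True."""
    left, right = 0, num_experts
    result = 0
    while left <= right:
        mid = (left + right) // 2
        if is_viable(mid):
            result = mid
            left = mid + 1   # fits: try placing more experts on GPU
        else:
            right = mid - 1  # fails (e.g. OOM): back off
    return result

# Stub viability check: pretend at most 96 of 256 experts fit in VRAM
print(find_max_viable(256, lambda n: n <= 96))  # -> 96
# Worst-case number of server launches for 256 experts
print(math.ceil(math.log2(256 + 1)))            # -> 9
```

Each probe costs one full server start in the real flow, so the logarithmic probe count (9 launches for a 256-expert model instead of 257) is what makes the tuning practical.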

@@ -0,0 +1,302 @@
"""
User Model Registry
Manages user-registered models in ~/.ktransformers/user_models.yaml
"""
from dataclasses import dataclass, asdict, field
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any
import yaml
# Constants
USER_MODELS_FILE = Path.home() / ".ktransformers" / "user_models.yaml"
REGISTRY_VERSION = "1.0"
@dataclass
class UserModel:
"""Represents a user-registered model"""
name: str # User-editable name (default: folder name)
path: str # Absolute path to model directory
format: str # "safetensors" | "gguf"
id: Optional[str] = None # Unique UUID for this model (auto-generated if None)
repo_type: Optional[str] = None # "huggingface" | "modelscope" | None
repo_id: Optional[str] = None # e.g., "deepseek-ai/DeepSeek-V3"
sha256_status: str = "not_checked" # "not_checked" | "checking" | "passed" | "failed" | "no_repo"
gpu_model_ids: Optional[List[str]] = None # For llamafile/AMX: list of GPU model UUIDs to run with
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
last_verified: Optional[str] = None # ISO format datetime
# MoE information (cached from analyze_moe_model)
is_moe: Optional[bool] = None # True if MoE model, False if non-MoE, None if not analyzed
moe_num_experts: Optional[int] = None # Total number of experts (for MoE models)
moe_num_experts_per_tok: Optional[int] = None # Number of active experts per token (for MoE models)
# AMX quantization metadata (for format == "amx")
amx_source_model: Optional[str] = None # Name of the source MoE model that was quantized
amx_quant_method: Optional[str] = None # "int4" | "int8"
amx_numa_nodes: Optional[int] = None # Number of NUMA nodes used for quantization
def __post_init__(self):
"""Ensure ID is set after initialization"""
if self.id is None:
import uuid
self.id = str(uuid.uuid4())
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for YAML serialization"""
return asdict(self)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "UserModel":
"""Create from dictionary loaded from YAML"""
return cls(**data)
def path_exists(self) -> bool:
"""Check if model path still exists"""
return Path(self.path).exists()

class UserModelRegistry:
    """Manages the user model registry"""

    def __init__(self, registry_file: Optional[Path] = None):
        """
        Initialize the registry

        Args:
            registry_file: Path to the registry YAML file (default: USER_MODELS_FILE)
        """
        self.registry_file = registry_file or USER_MODELS_FILE
        self.models: List[UserModel] = []
        self.version = REGISTRY_VERSION

        # Ensure directory exists
        self.registry_file.parent.mkdir(parents=True, exist_ok=True)

        # Load existing registry
        self.load()

    def load(self) -> None:
        """Load models from YAML file"""
        if not self.registry_file.exists():
            # Initialize empty registry
            self.models = []
            self.save()  # Create the file
            return

        try:
            with open(self.registry_file, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)

            if not data:
                self.models = []
                return

            # Load version
            self.version = data.get("version", REGISTRY_VERSION)

            # Load models
            models_data = data.get("models", [])
            self.models = [UserModel.from_dict(m) for m in models_data]

            # Migrate: ensure all models have UUIDs (for backward compatibility)
            needs_save = False
            for model in self.models:
                if model.id is None:
                    import uuid

                    model.id = str(uuid.uuid4())
                    needs_save = True
            if needs_save:
                self.save()
        except Exception as e:
            raise RuntimeError(f"Failed to load user model registry: {e}")

    def save(self) -> None:
        """Save models to YAML file"""
        data = {"version": self.version, "models": [m.to_dict() for m in self.models]}
        try:
            with open(self.registry_file, "w", encoding="utf-8") as f:
                yaml.safe_dump(data, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
        except Exception as e:
            raise RuntimeError(f"Failed to save user model registry: {e}")

    def add_model(self, model: UserModel) -> None:
        """
        Add a model to the registry

        Args:
            model: UserModel instance to add

        Raises:
            ValueError: If a model with the same name already exists
        """
        if self.check_name_conflict(model.name):
            raise ValueError(f"Model with name '{model.name}' already exists")
        self.models.append(model)
        self.save()

    def remove_model(self, name: str) -> bool:
        """
        Remove a model from the registry

        Args:
            name: Name of the model to remove

        Returns:
            True if model was removed, False if not found
        """
        original_count = len(self.models)
        self.models = [m for m in self.models if m.name != name]
        if len(self.models) < original_count:
            self.save()
            return True
        return False

    def update_model(self, name: str, updates: Dict[str, Any]) -> bool:
        """
        Update a model's attributes

        Args:
            name: Name of the model to update
            updates: Dictionary of attributes to update

        Returns:
            True if model was updated, False if not found
        """
        model = self.get_model(name)
        if not model:
            return False

        # Update attributes
        for key, value in updates.items():
            if hasattr(model, key):
                setattr(model, key, value)

        self.save()
        return True

    def get_model(self, name: str) -> Optional[UserModel]:
        """
        Get a model by name

        Args:
            name: Name of the model

        Returns:
            UserModel instance or None if not found
        """
        for model in self.models:
            if model.name == name:
                return model
        return None

    def get_model_by_id(self, model_id: str) -> Optional[UserModel]:
        """
        Get a model by its unique ID

        Args:
            model_id: UUID of the model

        Returns:
            UserModel instance or None if not found
        """
        for model in self.models:
            if model.id == model_id:
                return model
        return None

    def list_models(self) -> List[UserModel]:
        """
        List all models

        Returns:
            List of all UserModel instances
        """
        return self.models.copy()

    def find_by_path(self, path: str) -> Optional[UserModel]:
        """
        Find a model by its path

        Args:
            path: Model directory path

        Returns:
            UserModel instance or None if not found
        """
        # Normalize paths for comparison
        search_path = str(Path(path).resolve())
        for model in self.models:
            model_path = str(Path(model.path).resolve())
            if model_path == search_path:
                return model
        return None

    def check_name_conflict(self, name: str, exclude_name: Optional[str] = None) -> bool:
        """
        Check if a name conflicts with existing models

        Args:
            name: Name to check
            exclude_name: Optional name to exclude from check (for rename operations)

        Returns:
            True if conflict exists, False otherwise
        """
        for model in self.models:
            if model.name == name and model.name != exclude_name:
                return True
        return False

    def refresh_status(self) -> Dict[str, List[str]]:
        """
        Check all models and identify missing ones

        Returns:
            Dictionary with 'valid' and 'missing' lists of model names
        """
        valid = []
        missing = []
        for model in self.models:
            if model.path_exists():
                valid.append(model.name)
            else:
                missing.append(model.name)
        return {"valid": valid, "missing": missing}

    def get_model_count(self) -> int:
        """Get total number of registered models"""
        return len(self.models)

    def suggest_name(self, base_name: str) -> str:
        """
        Suggest a unique name based on base_name

        Args:
            base_name: Base name to derive from

        Returns:
            A unique name (may have a suffix like -2, -3, etc.)
        """
        if not self.check_name_conflict(base_name):
            return base_name

        counter = 2
        while True:
            candidate = f"{base_name}-{counter}"
            if not self.check_name_conflict(candidate):
                return candidate
            counter += 1
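`suggest_name` implements a simple collision-avoiding suffix scheme. The same logic, sketched against a plain set of names rather than the registry object (a standalone illustration, not the registry API itself):

```python
def suggest_name(base_name: str, existing: set) -> str:
    """Return base_name unchanged, or the first free base_name-2, -3, ..."""
    if base_name not in existing:
        return base_name
    counter = 2
    while True:
        candidate = f"{base_name}-{counter}"
        if candidate not in existing:
            return candidate
        counter += 1

print(suggest_name("deepseek-v3", {"qwen3"}))       # -> deepseek-v3
print(suggest_name("qwen3", {"qwen3", "qwen3-2"}))  # -> qwen3-3
```

Since registered names default to the model's folder name, this keeps repeated `kt model add` runs for same-named folders from colliding without prompting the user for a new name.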