@@ -4,6 +4,8 @@ Doctor command for kt-cli.
Diagnoses environment issues and provides recommendations.
"""

+import glob
import os
import platform
import shutil
from pathlib import Path
@@ -29,6 +31,67 @@ from kt_kernel.cli.utils.environment import (
)
+def _get_kt_kernel_info() -> dict:
+    """Get kt-kernel installation information."""
+    info = {
+        "installed": False,
+        "version": None,
+        "cpu_variant": None,
+        "install_path": None,
+        "available_variants": [],
+        "extension_file": None,
+    }
+
+    try:
+        import kt_kernel
+
+        info["installed"] = True
+        info["version"] = getattr(kt_kernel, "__version__", "unknown")
+        info["cpu_variant"] = getattr(kt_kernel, "__cpu_variant__", "unknown")
+
+        # Get installation path
+        info["install_path"] = os.path.dirname(kt_kernel.__file__)
+
+        # Find available .so files
+        kt_kernel_dir = info["install_path"]
+        so_files = glob.glob(os.path.join(kt_kernel_dir, "_kt_kernel_ext_*.so"))
+        so_files.extend(glob.glob(os.path.join(kt_kernel_dir, "kt_kernel_ext*.so")))
+
+        # Parse variant names from filenames
+        variants = set()
+        for so_file in so_files:
+            basename = os.path.basename(so_file)
+            if "_kt_kernel_ext_" in basename:
+                # Extract variant from e.g. _kt_kernel_ext_amx.cpython-311-x86_64-linux-gnu.so.
+                # Split on the prefix rather than on "_", since variant names such as
+                # "avx512_bf16" themselves contain underscores.
+                variant = basename.split("_kt_kernel_ext_", 1)[1].split(".", 1)[0]
+                if variant.startswith("avx512"):
+                    # Normalize avx512 sub-variants (bf16, vbmi, vnni, base) to "avx512"
+                    variants.add("avx512")
+                else:
+                    variants.add(variant)
+            elif "kt_kernel_ext" in basename:
+                variants.add("default")
+
+        info["available_variants"] = sorted(list(variants))
+
+        # Get current extension file
+        if hasattr(kt_kernel, "kt_kernel_ext"):
+            ext_module = kt_kernel.kt_kernel_ext
+            info["extension_file"] = getattr(ext_module, "__file__", None)
+
+    except ImportError:
+        info["installed"] = False
+    except Exception as e:
+        info["error"] = str(e)
+
+    return info


def doctor(
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed diagnostics"),
) -> None:
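As a sanity check on the parsing above, here is how a few representative extension filenames map to variants. The filenames are illustrative; the exact set depends on how kt-kernel was built:

# Illustrative filenames only; actual .so names depend on the kt-kernel build.
examples = {
    "_kt_kernel_ext_amx.cpython-311-x86_64-linux-gnu.so": "amx",
    "_kt_kernel_ext_avx512_bf16.cpython-311-x86_64-linux-gnu.so": "avx512",  # normalized
    "_kt_kernel_ext_avx2.cpython-311-x86_64-linux-gnu.so": "avx2",
    "kt_kernel_ext.cpython-311-x86_64-linux-gnu.so": "default",
}
for basename, expected in examples.items():
    if "_kt_kernel_ext_" in basename:
        variant = basename.split("_kt_kernel_ext_", 1)[1].split(".", 1)[0]
        variant = "avx512" if variant.startswith("avx512") else variant
    else:
        variant = "default"
    assert variant == expected, (basename, variant)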
@@ -157,6 +220,76 @@ def doctor(
        }
    )

+    # 6b. kt-kernel installation check
+    kt_info = _get_kt_kernel_info()
+
+    if kt_info["installed"]:
+        # Build display string for kt-kernel
+        variant = kt_info["cpu_variant"]
+        version = kt_info["version"]
+        available_variants = kt_info["available_variants"]
+
+        # Determine status based on CPU variant
+        if variant == "amx":
+            kt_status = "ok"
+            kt_hint = "AMX variant loaded - optimal performance"
+        elif variant.startswith("avx512"):
+            kt_status = "ok"
+            kt_hint = "AVX512 variant loaded - good performance"
+        elif variant == "avx2":
+            kt_status = "warning"
+            kt_hint = "AVX2 variant - consider upgrading CPU for AMX/AVX512"
+        else:
+            kt_status = "warning"
+            kt_hint = f"Unknown variant: {variant}"
+
+        kt_value = f"v{version} ({variant.upper()})"
+        if verbose and available_variants:
+            kt_value += f" [dim] - available: {', '.join(available_variants)}[/dim]"
+
+        checks.append(
+            {
+                "name": "kt-kernel",
+                "status": kt_status,
+                "value": kt_value,
+                "hint": kt_hint,
+            }
+        )
+
+        # Show extension file path in verbose mode
+        if verbose and kt_info.get("extension_file"):
+            ext_file = os.path.basename(kt_info["extension_file"])
+            checks.append(
+                {
+                    "name": " └─ Extension",
+                    "status": "ok",
+                    "value": ext_file,
+                    "hint": None,
+                }
+            )
+
+        # Show installation path in verbose mode
+        if verbose and kt_info.get("install_path"):
+            checks.append(
+                {
+                    "name": " └─ Path",
+                    "status": "ok",
+                    "value": kt_info["install_path"],
+                    "hint": None,
+                }
+            )
+    else:
+        error_msg = kt_info.get("error", "Not installed")
+        checks.append(
+            {
+                "name": "kt-kernel",
+                "status": "error",
+                "value": error_msg,
+                "hint": "kt-kernel is required - run: pip install kt-kernel",
+            }
+        )
+        issues_found = True

    # 7. System memory (with frequency if available)
    mem_info = detect_memory_info()
    if mem_info.frequency_mhz and mem_info.type:
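The checks entries above are plain dicts carrying Rich markup in their values (note the [dim] tag). A minimal sketch of how such rows could be rendered with Rich; the actual doctor table code is outside this diff, so the names here are hypothetical:

from rich.console import Console
from rich.table import Table

STATUS_ICONS = {"ok": "[green]✓[/green]", "warning": "[yellow]![/yellow]", "error": "[red]✗[/red]"}

def render_checks(checks: list[dict]) -> None:
    # Hypothetical renderer for the checks list built above.
    table = Table(show_header=False, box=None)
    for check in checks:
        table.add_row(
            STATUS_ICONS.get(check["status"], "?"),
            check["name"],
            str(check["value"]),
            f"[dim]{check.get('hint') or ''}[/dim]",
        )
    Console().print(table)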
@@ -204,7 +337,6 @@ def doctor(
    # 6. Required packages
    packages = [
        ("kt-kernel", ">=0.4.0", False),  # name, version_req, required
        ("ktransformers", ">=0.4.0", False),
        ("sglang", ">=0.4.0", False),
        ("torch", ">=2.4.0", True),
        ("transformers", ">=4.45.0", True),
@@ -10,6 +10,7 @@ import sys
from pathlib import Path
from typing import Optional

+import click
import typer

from kt_kernel.cli.config.settings import get_settings
@@ -30,128 +31,163 @@ from kt_kernel.cli.utils.environment import detect_cpu_info, detect_gpus, detect
from kt_kernel.cli.utils.model_registry import MODEL_COMPUTE_FUNCTIONS, ModelInfo, get_registry

+@click.command(
+    context_settings={"ignore_unknown_options": True, "allow_extra_args": True},
+    add_help_option=False,  # We'll handle help manually to avoid conflicts
+)
+@click.argument("model", required=False, default=None)
+@click.option("--host", "-H", default=None, help="Server host address")
+@click.option("--port", "-p", type=int, default=None, help="Server port")
+@click.option("--gpu-experts", type=int, default=None, help="Number of GPU experts per layer")
+@click.option("--cpu-threads", type=int, default=None, help="Number of CPU inference threads")
+@click.option("--numa-nodes", type=int, default=None, help="Number of NUMA nodes")
+@click.option(
+    "--tensor-parallel-size", "--tp", "tensor_parallel_size", type=int, default=None, help="Tensor parallel size"
+)
+@click.option("--model-path", type=click.Path(), default=None, help="Custom model path")
+@click.option("--weights-path", type=click.Path(), default=None, help="Custom quantized weights path")
+@click.option("--kt-method", default=None, help="KT quantization method")
+@click.option(
+    "--kt-gpu-prefill-threshold", "kt_gpu_prefill_threshold", type=int, default=None, help="GPU prefill token threshold"
+)
+@click.option("--attention-backend", default=None, help="Attention backend")
+@click.option("--max-total-tokens", "max_total_tokens", type=int, default=None, help="Maximum total tokens")
+@click.option("--max-running-requests", "max_running_requests", type=int, default=None, help="Maximum running requests")
+@click.option("--chunked-prefill-size", "chunked_prefill_size", type=int, default=None, help="Chunked prefill size")
+@click.option("--mem-fraction-static", "mem_fraction_static", type=float, default=None, help="Memory fraction static")
+@click.option("--watchdog-timeout", "watchdog_timeout", type=int, default=None, help="Watchdog timeout")
+@click.option("--served-model-name", "served_model_name", default=None, help="Served model name")
+@click.option(
+    "--disable-shared-experts-fusion",
+    "disable_shared_experts_fusion",
+    is_flag=True,
+    default=None,
+    help="Disable shared experts fusion",
+)
+@click.option(
+    "--enable-shared-experts-fusion",
+    "enable_shared_experts_fusion",
+    is_flag=True,
+    default=False,
+    help="Enable shared experts fusion",
+)
+@click.option("--quantize", "-q", is_flag=True, default=False, help="Quantize model")
+@click.option("--advanced", is_flag=True, default=False, help="Show advanced options")
+@click.option("--dry-run", "dry_run", is_flag=True, default=False, help="Show command without executing")
+@click.pass_context
def run(
-    model: Optional[str] = typer.Argument(
-        None,
-        help="Model name or path (e.g., deepseek-v3, qwen3-30b). If not specified, shows interactive selection.",
-    ),
-    host: str = typer.Option(
-        None,
-        "--host",
-        "-H",
-        help="Server host address",
-    ),
-    port: int = typer.Option(
-        None,
-        "--port",
-        "-p",
-        help="Server port",
-    ),
-    # CPU/GPU configuration
-    gpu_experts: Optional[int] = typer.Option(
-        None,
-        "--gpu-experts",
-        help="Number of GPU experts per layer",
-    ),
-    cpu_threads: Optional[int] = typer.Option(
-        None,
-        "--cpu-threads",
-        help="Number of CPU inference threads (kt-cpuinfer, defaults to 80% of CPU cores)",
-    ),
-    numa_nodes: Optional[int] = typer.Option(
-        None,
-        "--numa-nodes",
-        help="Number of NUMA nodes",
-    ),
-    tensor_parallel_size: Optional[int] = typer.Option(
-        None,
-        "--tensor-parallel-size",
-        "--tp",
-        help="Tensor parallel size (number of GPUs)",
-    ),
-    # Model paths
-    model_path: Optional[Path] = typer.Option(
-        None,
-        "--model-path",
-        help="Custom model path",
-    ),
-    weights_path: Optional[Path] = typer.Option(
-        None,
-        "--weights-path",
-        help="Custom quantized weights path",
-    ),
-    # KT-kernel options
-    kt_method: Optional[str] = typer.Option(
-        None,
-        "--kt-method",
-        help="KT quantization method (AMXINT4, RAWFP8, etc.)",
-    ),
-    kt_gpu_prefill_token_threshold: Optional[int] = typer.Option(
-        None,
-        "--kt-gpu-prefill-threshold",
-        help="GPU prefill token threshold for kt-kernel",
-    ),
-    # SGLang options
-    attention_backend: Optional[str] = typer.Option(
-        None,
-        "--attention-backend",
-        help="Attention backend (triton, flashinfer)",
-    ),
-    max_total_tokens: Optional[int] = typer.Option(
-        None,
-        "--max-total-tokens",
-        help="Maximum total tokens",
-    ),
-    max_running_requests: Optional[int] = typer.Option(
-        None,
-        "--max-running-requests",
-        help="Maximum running requests",
-    ),
-    chunked_prefill_size: Optional[int] = typer.Option(
-        None,
-        "--chunked-prefill-size",
-        help="Chunked prefill size",
-    ),
-    mem_fraction_static: Optional[float] = typer.Option(
-        None,
-        "--mem-fraction-static",
-        help="Memory fraction for static allocation",
-    ),
-    watchdog_timeout: Optional[int] = typer.Option(
-        None,
-        "--watchdog-timeout",
-        help="Watchdog timeout in seconds",
-    ),
-    served_model_name: Optional[str] = typer.Option(
-        None,
-        "--served-model-name",
-        help="Custom model name for API responses",
-    ),
-    # Performance flags
-    disable_shared_experts_fusion: Optional[bool] = typer.Option(
-        None,
-        "--disable-shared-experts-fusion/--enable-shared-experts-fusion",
-        help="Disable/enable shared experts fusion",
-    ),
-    # Other options
-    quantize: bool = typer.Option(
-        False,
-        "--quantize",
-        "-q",
-        help="Quantize model if weights not found",
-    ),
-    advanced: bool = typer.Option(
-        False,
-        "--advanced",
-        help="Show advanced options",
-    ),
-    dry_run: bool = typer.Option(
-        False,
-        "--dry-run",
-        help="Show command without executing",
-    ),
+    ctx: click.Context,
+    model: Optional[str],
+    host: Optional[str],
+    port: Optional[int],
+    gpu_experts: Optional[int],
+    cpu_threads: Optional[int],
+    numa_nodes: Optional[int],
+    tensor_parallel_size: Optional[int],
+    model_path: Optional[str],
+    weights_path: Optional[str],
+    kt_method: Optional[str],
+    kt_gpu_prefill_threshold: Optional[int],
+    attention_backend: Optional[str],
+    max_total_tokens: Optional[int],
+    max_running_requests: Optional[int],
+    chunked_prefill_size: Optional[int],
+    mem_fraction_static: Optional[float],
+    watchdog_timeout: Optional[int],
+    served_model_name: Optional[str],
+    disable_shared_experts_fusion: Optional[bool],
+    enable_shared_experts_fusion: bool,
+    quantize: bool,
+    advanced: bool,
+    dry_run: bool,
) -> None:
-    """Start model inference server."""
+    """Start model inference server.
+
+    \b
+    Examples: kt run deepseek-v3 | kt run m2 --tensor-parallel-size 2 | kt run /path/to/model --gpu-experts 4
+
+    \b
+    Custom Options: Pass any SGLang server option directly (e.g., kt run m2 --fp8-gemm-backend triton).
+    Common: --fp8-gemm-backend, --tool-call-parser, --reasoning-parser, --dp-size, --enable-ma
+    For full list: python -m sglang.launch_server --help
+    """
+    # Handle --help manually since we disabled it
+    # Check sys.argv for --help or -h since ctx.args may not be set yet
+    if "--help" in sys.argv or "-h" in sys.argv:
+        click.echo(ctx.get_help())
+        return
+
+    # Handle disable/enable shared experts fusion flags
+    if enable_shared_experts_fusion:
+        disable_shared_experts_fusion = False
+    # otherwise leave disable_shared_experts_fusion as-is (None means "use the default")
+
+    # Convert Path objects from click
+    model_path_obj = Path(model_path) if model_path else None
+    weights_path_obj = Path(weights_path) if weights_path else None
+
+    # Get extra args that weren't parsed (unknown options)
+    # click stores these in ctx.args when ignore_unknown_options=True
+    extra_cli_args = list(ctx.args) if ctx.args else []
+
+    # Remove --help from extra args if present (already handled)
+    extra_cli_args = [arg for arg in extra_cli_args if arg not in ["--help", "-h"]]
+
+    # Call the actual run function implementation
+    _run_impl(
+        model=model,
+        host=host,
+        port=port,
+        gpu_experts=gpu_experts,
+        cpu_threads=cpu_threads,
+        numa_nodes=numa_nodes,
+        tensor_parallel_size=tensor_parallel_size,
+        model_path=model_path_obj,
+        weights_path=weights_path_obj,
+        kt_method=kt_method,
+        kt_gpu_prefill_threshold=kt_gpu_prefill_threshold,
+        attention_backend=attention_backend,
+        max_total_tokens=max_total_tokens,
+        max_running_requests=max_running_requests,
+        chunked_prefill_size=chunked_prefill_size,
+        mem_fraction_static=mem_fraction_static,
+        watchdog_timeout=watchdog_timeout,
+        served_model_name=served_model_name,
+        disable_shared_experts_fusion=disable_shared_experts_fusion,
+        quantize=quantize,
+        advanced=advanced,
+        dry_run=dry_run,
+        extra_cli_args=extra_cli_args,
+    )
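The pass-through behavior comes from ignore_unknown_options plus allow_extra_args: click parses the options it knows and leaves everything else in ctx.args. A self-contained toy command showing the same pattern (hypothetical names, not kt code):

import click

@click.command(context_settings={"ignore_unknown_options": True, "allow_extra_args": True})
@click.option("--port", type=int, default=8000)
@click.pass_context
def serve(ctx: click.Context, port: int) -> None:
    # Known options are parsed; unknown ones land in ctx.args untouched.
    click.echo(f"port={port}, passthrough={ctx.args}")

# `serve --port 9000 --fp8-gemm-backend triton` prints:
#   port=9000, passthrough=['--fp8-gemm-backend', 'triton']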
+def _run_impl(
+    model: Optional[str],
+    host: Optional[str],
+    port: Optional[int],
+    gpu_experts: Optional[int],
+    cpu_threads: Optional[int],
+    numa_nodes: Optional[int],
+    tensor_parallel_size: Optional[int],
+    model_path: Optional[Path],
+    weights_path: Optional[Path],
+    kt_method: Optional[str],
+    kt_gpu_prefill_threshold: Optional[int],
+    attention_backend: Optional[str],
+    max_total_tokens: Optional[int],
+    max_running_requests: Optional[int],
+    chunked_prefill_size: Optional[int],
+    mem_fraction_static: Optional[float],
+    watchdog_timeout: Optional[int],
+    served_model_name: Optional[str],
+    disable_shared_experts_fusion: Optional[bool],
+    quantize: bool,
+    advanced: bool,
+    dry_run: bool,
+    extra_cli_args: list[str],
+) -> None:
+    """Actual implementation of run command."""
    # Check if SGLang is installed before proceeding
    from kt_kernel.cli.utils.sglang_checker import (
        check_sglang_installation,
@@ -387,7 +423,7 @@ def run(
    # KT-kernel options
    final_kt_method = kt_method or model_defaults.get("kt-method") or settings.get("inference.kt_method", "AMXINT4")
    final_kt_gpu_prefill_threshold = (
-        kt_gpu_prefill_token_threshold
+        kt_gpu_prefill_threshold
        or model_defaults.get("kt-gpu-prefill-token-threshold")
        or settings.get("inference.kt_gpu_prefill_token_threshold", 4096)
    )
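The or-chain resolves precedence as CLI flag, then per-model default, then global setting. One caveat worth noting: `or` treats 0 as falsy, so an explicit 0 from the CLI would fall through to the next source:

# Precedence: CLI flag > model default > settings default.
def resolve(cli_value, model_default, settings_default):
    return cli_value or model_default or settings_default

assert resolve(None, None, 4096) == 4096   # nothing set -> settings default
assert resolve(8192, None, 4096) == 8192   # CLI wins
assert resolve(0, None, 4096) == 4096      # caveat: falsy 0 is skipped by `or`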
@@ -456,6 +492,7 @@ def run(
        disable_shared_experts_fusion=final_disable_shared_experts_fusion,
        settings=settings,
        extra_model_params=extra_params,
+        extra_cli_args=extra_cli_args,
    )

    # Prepare environment variables
@@ -535,29 +572,51 @@ def run(
        raise typer.Exit(1)


-def _find_model_path(model_info: ModelInfo, settings) -> Optional[Path]:
-    """Find the model path on disk by searching all configured model paths."""
+def _find_model_path(model_info: ModelInfo, settings, max_depth: int = 3) -> Optional[Path]:
+    """Find the model path on disk by searching all configured model paths.
+
+    Args:
+        model_info: Model information to search for
+        settings: Settings instance
+        max_depth: Maximum depth to search within each model path (default: 3)
+
+    Returns:
+        Path to the model directory, or None if not found
+    """
    model_paths = settings.get_model_paths()

+    # Generate possible names to search for
+    possible_names = [
+        model_info.name,
+        model_info.name.lower(),
+        model_info.name.replace(" ", "-"),
+        model_info.hf_repo.split("/")[-1],
+        model_info.hf_repo.replace("/", "--"),
+    ]
+
+    # Add alias-based names
+    for alias in model_info.aliases:
+        possible_names.append(alias)
+        possible_names.append(alias.lower())
+
    # Search in all configured model directories
    for models_dir in model_paths:
-        # Check common path patterns
-        possible_paths = [
-            models_dir / model_info.name,
-            models_dir / model_info.name.lower(),
-            models_dir / model_info.name.replace(" ", "-"),
-            models_dir / model_info.hf_repo.split("/")[-1],
-            models_dir / model_info.hf_repo.replace("/", "--"),
-        ]
-
-        # Add alias-based paths
-        for alias in model_info.aliases:
-            possible_paths.append(models_dir / alias)
-            possible_paths.append(models_dir / alias.lower())
-
-        for path in possible_paths:
-            if path.exists() and (path / "config.json").exists():
-                return path
+        if not models_dir.exists():
+            continue
+
+        # Search recursively up to max_depth: depth 0 checks direct children,
+        # deeper levels use depth-limited glob patterns (*/name, */*/name, ...)
+        # so the documented max_depth bound is actually honored
+        for depth in range(max_depth):
+            for name in possible_names:
+                pattern = "/".join(["*"] * depth + [name])
+                for path in models_dir.glob(pattern):
+                    if path.is_dir() and (path / "config.json").exists():
+                        return path

    return None
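For a concrete picture, a hypothetical registry entry with name "DeepSeek-V3" and hf_repo "deepseek-ai/DeepSeek-V3" would produce these candidates and depth-limited patterns (values for illustration only):

name, hf_repo = "DeepSeek-V3", "deepseek-ai/DeepSeek-V3"  # hypothetical entry
possible_names = [
    name, name.lower(), name.replace(" ", "-"),
    hf_repo.split("/")[-1], hf_repo.replace("/", "--"),
]
# -> ['DeepSeek-V3', 'deepseek-v3', 'DeepSeek-V3', 'DeepSeek-V3', 'deepseek-ai--DeepSeek-V3']

for depth in range(3):  # max_depth = 3
    print("/".join(["*"] * depth + ["deepseek-v3"]))
# deepseek-v3
# */deepseek-v3
# */*/deepseek-v3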
@@ -613,6 +672,7 @@ def _build_sglang_command(
    disable_shared_experts_fusion: bool,
    settings,
    extra_model_params: Optional[dict] = None,  # New parameter for additional params
+    extra_cli_args: Optional[list[str]] = None,  # Extra args from CLI to pass to sglang
) -> list[str]:
    """Build the SGLang launch command."""
    cmd = [
@@ -734,6 +794,10 @@ def _build_sglang_command(
    if extra_args:
        cmd.extend(extra_args)

+    # Add extra CLI args (user-provided options not defined in kt CLI)
+    if extra_cli_args:
+        cmd.extend(extra_cli_args)
+
    return cmd
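End to end, the assembled command is the SGLang launcher followed by kt-derived options, with extra_cli_args appended verbatim at the end. A hypothetical result — flag values are invented for illustration, and most of _build_sglang_command is outside this diff:

cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "/models/deepseek-v3",   # hypothetical paths/values
    "--host", "0.0.0.0",
    "--port", "8000",
    # ...options derived from kt settings and model defaults...
    "--fp8-gemm-backend", "triton",          # extra_cli_args, passed through verbatim
]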
@@ -68,7 +68,8 @@ def _update_help_texts() -> None:

# Register commands
app.command(name="version", help="Show version information")(version.version)
-app.command(name="run", help="Start model inference server")(run.run)
+# Run command is handled specially in main() to allow extra args
+# (not registered here to avoid typer's argument parsing)
app.command(name="chat", help="Interactive chat with running model")(chat.chat)
app.command(name="quant", help="Quantize model weights")(quant.quant)
app.command(name="bench", help="Run full benchmark")(bench.bench)
@@ -429,6 +430,15 @@ def main():
    if should_check_first_run and args:
        check_first_run()

+    # Handle "run" command specially to pass through unknown options
+    if args and args[0] == "run":
+        # Get args after "run"
+        run_args = args[1:]
+        # Use click command directly with ignore_unknown_options
+        from kt_kernel.cli.commands import run as run_module
+
+        sys.exit(run_module.run.main(args=run_args, standalone_mode=False))
+
    app()
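standalone_mode=False makes click's main() return the command's result instead of exiting on its own, which lets the wrapper above decide the process exit code. A toy illustration (hypothetical command):

import sys
import click

@click.command()
def hello() -> None:
    click.echo("hello")

if __name__ == "__main__":
    # With standalone_mode=False click returns instead of calling sys.exit itself.
    result = hello.main(args=[], standalone_mode=False)
    sys.exit(result)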
@@ -280,9 +280,12 @@ class ModelRegistry:
        """List all registered models."""
        return list(self._models.values())

-    def find_local_models(self) -> list[tuple[ModelInfo, Path]]:
+    def find_local_models(self, max_depth: int = 3) -> list[tuple[ModelInfo, Path]]:
        """Find models that are downloaded locally in any configured model path.
+
+        Args:
+            max_depth: Maximum depth to search within each model path (default: 3)

        Returns:
            List of (ModelInfo, path) tuples for local models
        """
@@ -297,18 +300,40 @@ class ModelRegistry:
                if not models_dir.exists():
                    continue

-                # Check common path patterns
-                possible_paths = [
-                    models_dir / model.name,
-                    models_dir / model.name.lower(),
-                    models_dir / model.hf_repo.split("/")[-1],
-                    models_dir / model.hf_repo.replace("/", "--"),
+                # Generate possible names to search for
+                possible_names = [
+                    model.name,
+                    model.name.lower(),
+                    model.hf_repo.split("/")[-1],
+                    model.hf_repo.replace("/", "--"),
                ]

-                for path in possible_paths:
-                    if path.exists() and (path / "config.json").exists():
-                        results.append((model, path))
-                        found = True
+                # Search recursively up to max_depth: depth 0 checks direct children,
+                # deeper levels use depth-limited glob patterns (*/name, */*/name, ...)
+                for depth in range(max_depth):
+                    for name in possible_names:
+                        pattern = "/".join(["*"] * depth + [name])
+                        for path in models_dir.glob(pattern):
+                            if path.is_dir() and (path / "config.json").exists():
+                                results.append((model, path))
+                                found = True
+                                break
+
+                        if found:
+                            break
+
+                    if found:
+                        break

                if found:
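The repeated if found: break ladder is how the nested loops are exited one level at a time; an equivalent, flatter formulation uses a helper with an early return — a sketch under the same assumptions, not the committed code:

from pathlib import Path
from typing import Optional

def _first_match(models_dir: Path, possible_names: list[str], max_depth: int = 3) -> Optional[Path]:
    # Early return replaces the break ladder.
    for depth in range(max_depth):
        for name in possible_names:
            pattern = "/".join(["*"] * depth + [name])
            for path in models_dir.glob(pattern):
                if path.is_dir() and (path / "config.json").exists():
                    return path
    return None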