mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-19 22:09:10 +00:00
239
kt-kernel/python/cli/commands/quant.py
Normal file
239
kt-kernel/python/cli/commands/quant.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
Quant command for kt-cli.
|
||||
|
||||
Quantizes model weights for CPU inference.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
|
||||
from kt_kernel.cli.config.settings import get_settings
|
||||
from kt_kernel.cli.i18n import t
|
||||
from kt_kernel.cli.utils.console import (
|
||||
confirm,
|
||||
console,
|
||||
create_progress,
|
||||
print_error,
|
||||
print_info,
|
||||
print_step,
|
||||
print_success,
|
||||
print_warning,
|
||||
)
|
||||
from kt_kernel.cli.utils.environment import detect_cpu_info
|
||||
|
||||
|
||||
class QuantMethod(str, Enum):
|
||||
"""Quantization method."""
|
||||
|
||||
INT4 = "int4"
|
||||
INT8 = "int8"
|
||||
|
||||
|
||||
def quant(
|
||||
model: str = typer.Argument(
|
||||
...,
|
||||
help="Model name or path to quantize",
|
||||
),
|
||||
method: QuantMethod = typer.Option(
|
||||
QuantMethod.INT4,
|
||||
"--method",
|
||||
"-m",
|
||||
help="Quantization method",
|
||||
),
|
||||
output: Optional[Path] = typer.Option(
|
||||
None,
|
||||
"--output",
|
||||
"-o",
|
||||
help="Output path for quantized weights",
|
||||
),
|
||||
input_type: str = typer.Option(
|
||||
"fp8",
|
||||
"--input-type",
|
||||
"-i",
|
||||
help="Input weight type (fp8, fp16, bf16)",
|
||||
),
|
||||
cpu_threads: Optional[int] = typer.Option(
|
||||
None,
|
||||
"--cpu-threads",
|
||||
help="Number of CPU threads for quantization",
|
||||
),
|
||||
numa_nodes: Optional[int] = typer.Option(
|
||||
None,
|
||||
"--numa-nodes",
|
||||
help="Number of NUMA nodes",
|
||||
),
|
||||
no_merge: bool = typer.Option(
|
||||
False,
|
||||
"--no-merge",
|
||||
help="Don't merge safetensor files",
|
||||
),
|
||||
yes: bool = typer.Option(
|
||||
False,
|
||||
"--yes",
|
||||
"-y",
|
||||
help="Skip confirmation prompts",
|
||||
),
|
||||
) -> None:
|
||||
"""Quantize model weights for CPU inference."""
|
||||
settings = get_settings()
|
||||
console.print()
|
||||
|
||||
# Resolve input path
|
||||
input_path = _resolve_input_path(model, settings)
|
||||
if input_path is None:
|
||||
print_error(t("quant_input_not_found", path=model))
|
||||
raise typer.Exit(1)
|
||||
|
||||
print_info(t("quant_input_path", path=str(input_path)))
|
||||
|
||||
# Resolve output path
|
||||
if output is None:
|
||||
output = input_path.parent / f"{input_path.name}-{method.value.upper()}"
|
||||
|
||||
print_info(t("quant_output_path", path=str(output)))
|
||||
print_info(t("quant_method", method=method.value.upper()))
|
||||
|
||||
# Detect CPU configuration
|
||||
cpu = detect_cpu_info()
|
||||
final_cpu_threads = cpu_threads or cpu.cores
|
||||
final_numa_nodes = numa_nodes or cpu.numa_nodes
|
||||
|
||||
print_info(f"CPU threads: {final_cpu_threads}")
|
||||
print_info(f"NUMA nodes: {final_numa_nodes}")
|
||||
|
||||
# Check if output exists
|
||||
if output.exists():
|
||||
print_warning(f"Output path already exists: {output}")
|
||||
if not yes:
|
||||
if not confirm("Overwrite?", default=False):
|
||||
raise typer.Abort()
|
||||
|
||||
# Confirm
|
||||
if not yes:
|
||||
console.print()
|
||||
console.print("[bold]Quantization Settings:[/bold]")
|
||||
console.print(f" Input: {input_path}")
|
||||
console.print(f" Output: {output}")
|
||||
console.print(f" Method: {method.value.upper()}")
|
||||
console.print(f" Input type: {input_type}")
|
||||
console.print()
|
||||
print_warning("Quantization may take 30-60 minutes depending on model size.")
|
||||
console.print()
|
||||
|
||||
if not confirm(t("prompt_continue")):
|
||||
raise typer.Abort()
|
||||
|
||||
# Find conversion script
|
||||
kt_kernel_path = _find_kt_kernel_path()
|
||||
if kt_kernel_path is None:
|
||||
print_error("kt-kernel not found. Install with: kt install inference")
|
||||
raise typer.Exit(1)
|
||||
|
||||
script_path = kt_kernel_path / "scripts" / "convert_cpu_weights.py"
|
||||
if not script_path.exists():
|
||||
print_error(f"Conversion script not found: {script_path}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Build command
|
||||
cmd = [
|
||||
sys.executable, str(script_path),
|
||||
"--input-path", str(input_path),
|
||||
"--input-type", input_type,
|
||||
"--output", str(output),
|
||||
"--quant-method", method.value,
|
||||
"--cpuinfer-threads", str(final_cpu_threads),
|
||||
"--threadpool-count", str(final_numa_nodes),
|
||||
]
|
||||
|
||||
if no_merge:
|
||||
cmd.append("--no-merge-safetensor")
|
||||
|
||||
# Run quantization
|
||||
console.print()
|
||||
print_step(t("quant_starting"))
|
||||
console.print()
|
||||
console.print(f"[dim]$ {' '.join(cmd)}[/dim]")
|
||||
console.print()
|
||||
|
||||
try:
|
||||
process = subprocess.run(cmd)
|
||||
|
||||
if process.returncode == 0:
|
||||
console.print()
|
||||
print_success(t("quant_complete"))
|
||||
console.print()
|
||||
console.print(f" Quantized weights saved to: {output}")
|
||||
console.print()
|
||||
console.print(" Use with:")
|
||||
console.print(f" kt run {model} --weights-path {output}")
|
||||
console.print()
|
||||
else:
|
||||
print_error(f"Quantization failed with exit code {process.returncode}")
|
||||
raise typer.Exit(process.returncode)
|
||||
|
||||
except FileNotFoundError as e:
|
||||
print_error(f"Failed to run quantization: {e}")
|
||||
raise typer.Exit(1)
|
||||
except KeyboardInterrupt:
|
||||
console.print()
|
||||
print_warning("Quantization interrupted.")
|
||||
raise typer.Exit(130)
|
||||
|
||||
|
||||
def _resolve_input_path(model: str, settings) -> Optional[Path]:
|
||||
"""Resolve the input model path."""
|
||||
# Check if it's already a path
|
||||
path = Path(model)
|
||||
if path.exists() and (path / "config.json").exists():
|
||||
return path
|
||||
|
||||
# Search in models directory
|
||||
from kt_kernel.cli.utils.model_registry import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
matches = registry.search(model)
|
||||
|
||||
if matches:
|
||||
model_info = matches[0]
|
||||
# Try to find in all configured model directories
|
||||
model_paths = settings.get_model_paths()
|
||||
|
||||
for models_dir in model_paths:
|
||||
possible_paths = [
|
||||
models_dir / model_info.name,
|
||||
models_dir / model_info.name.lower(),
|
||||
models_dir / model_info.hf_repo.split("/")[-1],
|
||||
]
|
||||
|
||||
for p in possible_paths:
|
||||
if p.exists() and (p / "config.json").exists():
|
||||
return p
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_kt_kernel_path() -> Optional[Path]:
|
||||
"""Find the kt-kernel installation path."""
|
||||
try:
|
||||
import kt_kernel
|
||||
return Path(kt_kernel.__file__).parent.parent
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Check common locations
|
||||
possible_paths = [
|
||||
Path.home() / "Projects" / "ktransformers" / "kt-kernel",
|
||||
Path.cwd().parent / "kt-kernel",
|
||||
Path.cwd() / "kt-kernel",
|
||||
]
|
||||
|
||||
for path in possible_paths:
|
||||
if path.exists() and (path / "scripts").exists():
|
||||
return path
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user