Files
ErvinXie d8046e1bb4 Kt minimax (#1742)
[feat]: fp8 kernel and kt-cli support
2025-12-24 15:39:44 +08:00

312 lines
9.5 KiB
Python

"""
Configuration management for kt-cli.
Handles reading and writing configuration from ~/.ktransformers/config.yaml
"""
import os
from pathlib import Path
from typing import Any, Optional
import yaml
# Default configuration directory
DEFAULT_CONFIG_DIR = Path.home() / ".ktransformers"
DEFAULT_CONFIG_FILE = DEFAULT_CONFIG_DIR / "config.yaml"
DEFAULT_MODELS_DIR = DEFAULT_CONFIG_DIR / "models"
DEFAULT_CACHE_DIR = DEFAULT_CONFIG_DIR / "cache"
# Default configuration values
DEFAULT_CONFIG = {
"general": {
"language": "auto", # auto, en, zh
"color": True,
"verbose": False,
},
"paths": {
"models": str(DEFAULT_MODELS_DIR),
"cache": str(DEFAULT_CACHE_DIR),
"weights": "", # Custom quantized weights path
},
"server": {
"host": "0.0.0.0",
"port": 30000,
},
"inference": {
# Inference parameters are model-specific and should not have defaults
# They will be auto-detected or use model-specific optimizations
# Environment variables (general optimizations)
"env": {
"PYTORCH_ALLOC_CONF": "expandable_segments:True",
"SGLANG_ENABLE_JIT_DEEPGEMM": "0",
},
},
"download": {
"mirror": "", # HuggingFace mirror URL
"resume": True,
"verify": True,
},
"advanced": {
# Environment variables to set when running
"env": {},
# Extra arguments to pass to sglang
"sglang_args": [],
# Extra arguments to pass to llamafactory
"llamafactory_args": [],
},
"dependencies": {
# SGLang installation source configuration
"sglang": {
"source": "github", # "pypi" or "github"
"repo": "https://github.com/kvcache-ai/sglang",
"branch": "main",
},
},
}
class Settings:
"""Configuration manager for kt-cli."""
def __init__(self, config_path: Optional[Path] = None):
"""Initialize settings manager.
Args:
config_path: Path to config file. Defaults to ~/.ktransformers/config.yaml
"""
self.config_path = config_path or DEFAULT_CONFIG_FILE
self.config_dir = self.config_path.parent
self._config: dict[str, Any] = {}
self._load()
def _ensure_dirs(self) -> None:
"""Ensure configuration directories exist."""
self.config_dir.mkdir(parents=True, exist_ok=True)
# Ensure all model paths exist
model_paths = self.get_model_paths()
for path in model_paths:
path.mkdir(parents=True, exist_ok=True)
Path(self.get("paths.cache", DEFAULT_CACHE_DIR)).mkdir(parents=True, exist_ok=True)
def _load(self) -> None:
"""Load configuration from file."""
self._config = self._deep_copy(DEFAULT_CONFIG)
if self.config_path.exists():
try:
with open(self.config_path, "r", encoding="utf-8") as f:
user_config = yaml.safe_load(f) or {}
self._deep_merge(self._config, user_config)
except (yaml.YAMLError, OSError) as e:
# Log warning but continue with defaults
print(f"Warning: Failed to load config: {e}")
self._ensure_dirs()
def _save(self) -> None:
"""Save configuration to file."""
self._ensure_dirs()
try:
with open(self.config_path, "w", encoding="utf-8") as f:
yaml.dump(self._config, f, default_flow_style=False, allow_unicode=True)
except OSError as e:
raise RuntimeError(f"Failed to save config: {e}")
def _deep_copy(self, obj: Any) -> Any:
"""Create a deep copy of a nested dict."""
if isinstance(obj, dict):
return {k: self._deep_copy(v) for k, v in obj.items()}
if isinstance(obj, list):
return [self._deep_copy(item) for item in obj]
return obj
def _deep_merge(self, base: dict, override: dict) -> None:
"""Deep merge override into base."""
for key, value in override.items():
if key in base and isinstance(base[key], dict) and isinstance(value, dict):
self._deep_merge(base[key], value)
else:
base[key] = value
def get(self, key: str, default: Any = None) -> Any:
"""Get a configuration value by dot-separated key.
Args:
key: Dot-separated key path (e.g., "server.port")
default: Default value if key not found
Returns:
Configuration value or default
"""
parts = key.split(".")
value = self._config
for part in parts:
if isinstance(value, dict) and part in value:
value = value[part]
else:
return default
return value
def set(self, key: str, value: Any) -> None:
"""Set a configuration value by dot-separated key.
Args:
key: Dot-separated key path (e.g., "server.port")
value: Value to set
"""
parts = key.split(".")
config = self._config
# Navigate to parent
for part in parts[:-1]:
if part not in config:
config[part] = {}
config = config[part]
# Set value
config[parts[-1]] = value
self._save()
def delete(self, key: str) -> bool:
"""Delete a configuration value.
Args:
key: Dot-separated key path
Returns:
True if key was deleted, False if not found
"""
parts = key.split(".")
config = self._config
# Navigate to parent
for part in parts[:-1]:
if part not in config:
return False
config = config[part]
# Delete key
if parts[-1] in config:
del config[parts[-1]]
self._save()
return True
return False
def reset(self) -> None:
"""Reset configuration to defaults."""
self._config = self._deep_copy(DEFAULT_CONFIG)
self._save()
def get_all(self) -> dict[str, Any]:
"""Get all configuration values."""
return self._deep_copy(self._config)
def get_env_vars(self) -> dict[str, str]:
"""Get environment variables to set."""
env_vars = {}
# Get from advanced.env
advanced_env = self.get("advanced.env", {})
if isinstance(advanced_env, dict):
env_vars.update({k: str(v) for k, v in advanced_env.items()})
return env_vars
@property
def models_dir(self) -> Path:
"""Get the primary models directory path (for backward compatibility)."""
paths = self.get_model_paths()
return paths[0] if paths else Path(DEFAULT_MODELS_DIR)
def get_model_paths(self) -> list[Path]:
"""Get all model directory paths.
Returns a list of Path objects. Supports both:
- Single path: paths.models = "/path/to/models"
- Multiple paths: paths.models = ["/path/1", "/path/2"]
"""
models_config = self.get("paths.models", DEFAULT_MODELS_DIR)
# Handle both string and list
if isinstance(models_config, str):
return [Path(models_config)]
elif isinstance(models_config, list):
return [Path(p) for p in models_config]
else:
return [Path(DEFAULT_MODELS_DIR)]
def add_model_path(self, path: str) -> None:
"""Add a new model path to the configuration."""
models_config = self.get("paths.models", DEFAULT_MODELS_DIR)
# Convert to list if it's a string
if isinstance(models_config, str):
paths = [models_config]
elif isinstance(models_config, list):
paths = list(models_config)
else:
paths = []
# Add new path if not already present
if path not in paths:
paths.append(path)
self.set("paths.models", paths)
def remove_model_path(self, path: str) -> bool:
"""Remove a model path from the configuration.
Returns True if path was removed, False if not found.
"""
models_config = self.get("paths.models", DEFAULT_MODELS_DIR)
if isinstance(models_config, str):
# Can't remove if it's a single string
if models_config == path:
# Don't remove the last path
return False
return False
elif isinstance(models_config, list):
if path in models_config:
paths = list(models_config)
paths.remove(path)
# Don't allow removing all paths
if not paths:
return False
self.set("paths.models", paths if len(paths) > 1 else paths[0])
return True
return False
@property
def cache_dir(self) -> Path:
"""Get the cache directory path."""
return Path(self.get("paths.cache", DEFAULT_CACHE_DIR))
@property
def weights_dir(self) -> Optional[Path]:
"""Get the custom weights directory path."""
weights = self.get("paths.weights", "")
return Path(weights) if weights else None
# Global settings instance
_settings: Optional[Settings] = None
def get_settings() -> Settings:
"""Get the global settings instance."""
global _settings
if _settings is None:
_settings = Settings()
return _settings
def reset_settings() -> None:
"""Reset the global settings instance."""
global _settings
_settings = None