Files
ktransformers/kt-kernel/python/cli/utils/environment.py
ErvinXie d8046e1bb4 Kt minimax (#1742)
[feat]: fp8 kernel and kt-cli support
2025-12-24 15:39:44 +08:00

1109 lines
35 KiB
Python

"""
Environment detection utilities for kt-cli.
Provides functions to detect:
- Virtual environment managers (conda, venv, uv, mamba)
- Python version and packages
- CUDA and GPU information
- System resources (CPU, RAM, disk)
"""
import os
import platform
import shutil
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
@dataclass
class EnvManager:
    """Information about an environment manager."""

    # Manager identifier, e.g. "conda", "mamba", "uv", "venv", or "docker".
    name: str
    # Reported version string ("unknown" if unparseable, "builtin" for venv).
    version: str
    # Filesystem path of the executable (or an invocation hint such as
    # "python -m venv" for the built-in venv module).
    path: str
@dataclass
class GPUInfo:
    """Information about a GPU."""

    # Device index as reported by nvidia-smi; re-numbered to the logical CUDA
    # index (0, 1, 2, ...) when CUDA_VISIBLE_DEVICES filtering applies.
    index: int
    # GPU product name string.
    name: str
    # Total VRAM in GiB, rounded to one decimal.
    vram_gb: float
    # Compute capability string; not populated by detect_gpus() in this module.
    cuda_capability: Optional[str] = None
@dataclass
class CPUInfo:
    """Information about the CPU."""

    # CPU model name; "Unknown" when it cannot be detected.
    name: str
    # Physical core count (best effort; may equal the thread count when
    # topology information is unavailable).
    cores: int
    # Logical processor (hardware thread) count.
    threads: int
    # Number of NUMA nodes (1 when the topology cannot be read).
    numa_nodes: int
    instruction_sets: list[str] = field(default_factory=list)  # AVX, AVX2, AVX512, AMX, etc.
    numa_info: dict = field(default_factory=dict)  # node -> cpus mapping
@dataclass
class MemoryInfo:
    """Information about system memory."""

    # Total installed RAM in GB.
    total_gb: float
    # Currently available RAM in GB.
    available_gb: float
    # DIMM speed as reported by dmidecode/system_profiler; None if unknown.
    frequency_mhz: Optional[int] = None
    # Populated with the installed-DIMM count from dmidecode (used as a
    # channel proxy); None if unknown.
    channels: Optional[int] = None
    type: Optional[str] = None  # DDR4, DDR5, etc.
@dataclass
class SystemInfo:
    """Complete system information."""

    python_version: str  # e.g. "3.11.4"
    platform: str  # "<system> <release>", e.g. "Linux 6.5.0"
    cuda_version: Optional[str]  # CUDA version string, or None if not found
    gpus: list[GPUInfo]  # Detected NVIDIA GPUs
    cpu: CPUInfo  # CPU details
    ram_gb: float  # Total system RAM in GB
    env_managers: list[EnvManager]  # Available environment managers
def run_command(cmd: list[str], timeout: int = 10) -> Optional[str]:
    """Execute *cmd* and return its stripped stdout, or None on any failure.

    Failure covers a non-zero exit status, a missing binary, an OS-level
    error, or exceeding *timeout* seconds.
    """
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, check=False)
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        return None
    return proc.stdout.strip() if proc.returncode == 0 else None
def detect_env_managers() -> list[EnvManager]:
    """Detect available virtual environment managers.

    Probes conda, mamba and uv on PATH (recording their parsed versions),
    and always reports the built-in venv module when it is importable.

    Returns:
        EnvManager entries in probe order: conda, mamba, uv, venv.
    """
    managers: list[EnvManager] = []

    def _probe(tool: str, parse_version) -> None:
        # Shared CLI probe: the three external tools differed only in how the
        # version string is parsed, so the duplicated blocks are folded here.
        tool_path = shutil.which(tool)
        if not tool_path:
            return
        raw = run_command([tool, "--version"])
        if raw:
            managers.append(EnvManager(name=tool, version=parse_version(raw), path=tool_path))

    _probe("conda", lambda v: v.split()[-1])  # "conda 24.1.0" -> "24.1.0"
    _probe("mamba", lambda v: v.split("\n")[0].split()[-1])  # first line: "mamba 1.5.0"
    _probe("uv", lambda v: v.split()[-1])  # "uv 0.5.0" -> "0.5.0"

    # venv ships with CPython; report it whenever the module is importable.
    try:
        import venv  # noqa: F401

        managers.append(EnvManager(name="venv", version="builtin", path="python -m venv"))
    except ImportError:
        pass
    return managers
def check_docker() -> Optional[EnvManager]:
    """Return an EnvManager entry for Docker if installed, else None."""
    docker_path = shutil.which("docker")
    if docker_path is None:
        return None
    output = run_command(["docker", "--version"])
    if not output:
        return None
    # Example output: "Docker version 24.0.7, build afdd53b"
    tokens = output.split()
    parsed = tokens[2].rstrip(",") if len(tokens) > 2 else "unknown"
    return EnvManager(name="docker", version=parsed, path=docker_path)
def check_kt_env_exists(manager: str, env_name: str = "kt") -> bool:
    """Return True if an environment named *env_name* exists for *manager*.

    Supported managers: "conda"/"mamba" (via `env list`), "uv" and "venv"
    (via well-known filesystem locations). Unknown managers yield False.
    """
    if manager in ("conda", "mamba"):
        listing = run_command([manager, "env", "list"])
        if listing:
            # The environment name is the first whitespace-separated token
            # on its row of the `env list` output.
            for row in listing.split("\n"):
                tokens = row.split()
                if tokens and tokens[0] == env_name:
                    return True
    elif manager == "uv":
        # uv keeps envs under ~/.local/share/uv/envs/ or a project-local dir.
        if (Path.home() / ".local" / "share" / "uv" / "envs" / env_name).exists():
            return True
        local = Path(env_name)
        if local.exists() and (local / "bin" / "python").exists():
            return True
    elif manager == "venv":
        # Common locations: ~/.virtualenvs/<name> or a local directory.
        if (Path.home() / ".virtualenvs" / env_name).exists():
            return True
        local = Path(env_name)
        if local.exists() and (local / "bin" / "python").exists():
            return True
    return False
def get_kt_env_path(manager: str, env_name: str = "kt") -> Optional[Path]:
    """Return the filesystem path of the named environment, or None if absent."""
    if manager in ("conda", "mamba"):
        listing = run_command([manager, "env", "list"])
        if listing:
            for row in listing.split("\n"):
                tokens = row.split()
                # Row format: "<name> [*] <path>" — the path is the last token.
                if tokens and tokens[0] == env_name:
                    return Path(tokens[-1])
    elif manager == "uv":
        candidate = Path.home() / ".local" / "share" / "uv" / "envs" / env_name
        if candidate.exists():
            return candidate
    elif manager == "venv":
        candidate = Path.home() / ".virtualenvs" / env_name
        if candidate.exists():
            return candidate
    return None
def detect_cuda_version() -> Optional[str]:
    """Detect the CUDA version via nvidia-smi, falling back to nvcc.

    Returns the version string (e.g. "12.1") or None if neither tool reports.
    """
    # First confirm a driver responds, then scrape "CUDA Version: X.Y" out of
    # the banner printed by argument-less nvidia-smi.
    if run_command(["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"]):
        banner = run_command(["nvidia-smi"])
        if banner:
            for row in banner.split("\n"):
                if "CUDA Version:" in row:
                    # "| CUDA Version: 12.1 |"
                    tail = row.split("CUDA Version:")
                    if len(tail) > 1:
                        return tail[1].strip().split()[0]
    # Fallback: "Cuda compilation tools, release 12.1, V12.1.105"
    nvcc_text = run_command(["nvcc", "--version"])
    if nvcc_text:
        for row in nvcc_text.split("\n"):
            if "release" in row.lower():
                tail = row.split("release")
                if len(tail) > 1:
                    return tail[1].strip().split(",")[0].strip()
    return None
def detect_gpus() -> list[GPUInfo]:
    """Detect NVIDIA GPUs via nvidia-smi, honoring CUDA_VISIBLE_DEVICES.

    Returns an empty list when nvidia-smi is unavailable or when
    CUDA_VISIBLE_DEVICES is set to the empty string.
    """
    gpus: list[GPUInfo] = []
    query = run_command(["nvidia-smi", "--query-gpu=index,name,memory.total", "--format=csv,noheader,nounits"])
    if query:
        for row in query.strip().split("\n"):
            fields = [f.strip() for f in row.split(",")]
            if len(fields) < 3:
                continue
            try:
                gpus.append(
                    GPUInfo(
                        index=int(fields[0]),
                        name=fields[1],
                        vram_gb=round(float(fields[2]) / 1024, 1),
                    )
                )
            except (ValueError, IndexError):
                continue
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return gpus
    if visible == "":
        # Empty string hides every GPU.
        return []
    try:
        allowed = _parse_cuda_visible_devices(visible)
    except ValueError:
        # Unparseable value (e.g. GPU UUIDs): report all GPUs as a fallback.
        return gpus
    selected = [g for g in gpus if g.index in allowed]
    # CUDA renumbers the visible devices as 0, 1, 2, ...
    for logical, g in enumerate(selected):
        g.index = logical
    return selected
def _parse_cuda_visible_devices(cuda_visible: str) -> list[int]:
"""Parse CUDA_VISIBLE_DEVICES string into list of GPU indices.
Supports formats like:
- "0,1,2,3" -> [0, 1, 2, 3]
- "0-3" -> [0, 1, 2, 3]
- "0,2-4,7" -> [0, 2, 3, 4, 7]
"""
indices = []
parts = cuda_visible.split(",")
for part in parts:
part = part.strip()
if "-" in part:
# Range like "0-3"
start, end = part.split("-")
indices.extend(range(int(start), int(end) + 1))
else:
# Single index
indices.append(int(part))
return sorted(set(indices)) # Remove duplicates and sort
def detect_cpu_info() -> CPUInfo:
    """Detect CPU information including instruction sets and NUMA topology.

    On Linux this parses /proc/cpuinfo (model name, core/thread counts,
    feature flags) and /sys/devices/system/node for NUMA layout. On macOS it
    queries sysctl. Any unreadable source degrades gracefully to defaults.

    Returns:
        CPUInfo; defaults are name="Unknown", cores=threads=os.cpu_count(),
        numa_nodes=1, empty instruction_sets/numa_info.
    """
    name = "Unknown"
    cores = os.cpu_count() or 1
    threads = cores
    numa_nodes = 1
    instruction_sets: list[str] = []
    numa_info: dict[str, list[int]] = {}
    if platform.system() == "Linux":
        try:
            with open("/proc/cpuinfo", "r") as f:
                content = f.read()
            # Get CPU name from the first "model name" entry.
            for line in content.split("\n"):
                if line.startswith("model name"):
                    name = line.split(":")[1].strip()
                    break
            # Logical CPU count = number of "processor :" entries.
            cpu_cores = content.count("processor\t:")
            if cpu_cores > 0:
                threads = cpu_cores
            # siblings = logical CPUs per package; "cpu cores" = physical
            # cores per package; their ratio is the SMT factor.
            siblings = None
            cores_per = None
            for line in content.split("\n"):
                if "siblings" in line:
                    siblings = int(line.split(":")[1].strip())
                if "cpu cores" in line:
                    cores_per = int(line.split(":")[1].strip())
            if siblings and cores_per:
                # Divide total threads by the SMT factor to estimate physical
                # cores; with SMT off (siblings == cores_per) they are equal.
                cores = threads // (siblings // cores_per) if siblings > cores_per else threads
            # Get instruction sets from the first "flags" line.
            for line in content.split("\n"):
                if line.startswith("flags"):
                    flags = line.split(":")[1].strip().split()
                    instruction_sets = _parse_cpu_flags(flags)
                    break
        except (OSError, IOError, ValueError):
            pass
        # NUMA topology: one "nodeN" directory per node, each listing its
        # CPUs in a "cpulist" file (e.g. "0-3,8-11").
        numa_path = Path("/sys/devices/system/node")
        if numa_path.exists():
            numa_dirs = [d for d in numa_path.iterdir() if d.name.startswith("node")]
            numa_nodes = len(numa_dirs)
            for node_dir in numa_dirs:
                node_name = node_dir.name  # e.g., "node0"
                cpulist_path = node_dir / "cpulist"
                if cpulist_path.exists():
                    try:
                        cpulist = cpulist_path.read_text().strip()
                        numa_info[node_name] = _parse_cpu_list(cpulist)
                    except (OSError, IOError):
                        pass
    elif platform.system() == "Darwin":
        # macOS: sysctl for name and core/thread counts.
        # NOTE(review): the int() calls below are not wrapped in try/except;
        # a non-numeric sysctl reply would raise ValueError — confirm intended.
        name_output = run_command(["sysctl", "-n", "machdep.cpu.brand_string"])
        if name_output:
            name = name_output.strip()
        cores_output = run_command(["sysctl", "-n", "hw.physicalcpu"])
        if cores_output:
            cores = int(cores_output.strip())
        threads_output = run_command(["sysctl", "-n", "hw.logicalcpu"])
        if threads_output:
            threads = int(threads_output.strip())
        # Instruction sets on macOS.
        # NOTE(review): newer features (e.g. AVX2) may be reported under
        # machdep.cpu.leaf7_features instead of .features — verify.
        features_output = run_command(["sysctl", "-n", "machdep.cpu.features"])
        if features_output:
            flags = features_output.lower().split()
            instruction_sets = _parse_cpu_flags(flags)
    return CPUInfo(
        name=name,
        cores=cores,
        threads=threads,
        numa_nodes=numa_nodes,
        instruction_sets=instruction_sets,
        numa_info=numa_info,
    )
def _parse_cpu_flags(flags: list[str]) -> list[str]:
"""Parse CPU flags to extract relevant instruction sets for KTransformers."""
# Instruction sets important for KTransformers/kt-kernel
relevant_instructions = {
# Basic SIMD
"sse": "SSE",
"sse2": "SSE2",
"sse3": "SSE3",
"ssse3": "SSSE3",
"sse4_1": "SSE4.1",
"sse4_2": "SSE4.2",
# AVX family
"avx": "AVX",
"avx2": "AVX2",
"avx512f": "AVX512F",
"avx512bw": "AVX512BW",
"avx512vl": "AVX512VL",
"avx512dq": "AVX512DQ",
"avx512cd": "AVX512CD",
"avx512vnni": "AVX512VNNI",
"avx512_bf16": "AVX512BF16",
"avx512_fp16": "AVX512FP16",
"avx_vnni": "AVX-VNNI",
# AMX (Advanced Matrix Extensions) - Intel
"amx_tile": "AMX-TILE",
"amx_bf16": "AMX-BF16",
"amx_int8": "AMX-INT8",
"amx_fp16": "AMX-FP16",
# Other relevant
"fma": "FMA",
"f16c": "F16C",
"bmi1": "BMI1",
"bmi2": "BMI2",
}
found = []
flags_lower = {f.lower() for f in flags}
for flag, display_name in relevant_instructions.items():
if flag in flags_lower:
found.append(display_name)
# Sort by importance for display
priority = [
"AMX-INT8",
"AMX-BF16",
"AMX-FP16",
"AMX-TILE",
"AVX512BF16",
"AVX512VNNI",
"AVX512F",
"AVX512BW",
"AVX512VL",
"AVX2",
"AVX",
"FMA",
"SSE4.2",
]
result = []
for p in priority:
if p in found:
result.append(p)
found.remove(p)
result.extend(sorted(found)) # Add remaining
return result
def _parse_cpu_list(cpulist: str) -> list[int]:
"""Parse CPU list string like '0-3,8-11' to list of CPU IDs."""
cpus = []
for part in cpulist.split(","):
if "-" in part:
start, end = part.split("-")
cpus.extend(range(int(start), int(end) + 1))
else:
cpus.append(int(part))
return cpus
def detect_memory_info() -> MemoryInfo:
    """Gather total/available RAM plus best-effort frequency, channel and type.

    Frequency/type come from dmidecode on Linux (if the user may run it) or
    system_profiler on macOS; fields stay None when undetectable.
    """
    freq: Optional[int] = None
    chans: Optional[int] = None
    mtype: Optional[str] = None
    system = platform.system()
    if system == "Linux":
        # dmidecode may succeed without sudo if the user has permissions.
        raw = run_command(["dmidecode", "-t", "memory"])
        if raw:
            freq, mtype, chans = _parse_dmidecode_memory(raw)
        if freq is None:
            # Fallback: sysfs probe (rarely yields anything).
            freq = _detect_memory_frequency_sysfs()
    elif system == "Darwin":
        raw = run_command(["system_profiler", "SPMemoryDataType"])
        if raw:
            freq, mtype = _parse_macos_memory(raw)
    return MemoryInfo(
        total_gb=detect_ram_gb(),
        available_gb=detect_available_ram_gb(),
        frequency_mhz=freq,
        channels=chans,
        type=mtype,
    )
def _parse_dmidecode_memory(output: str) -> tuple[Optional[int], Optional[str], Optional[int]]:
"""Parse dmidecode memory output."""
frequency_mhz: Optional[int] = None
mem_type: Optional[str] = None
dimm_count = 0
for line in output.split("\n"):
line = line.strip()
if line.startswith("Speed:") and "MHz" in line:
try:
# "Speed: 4800 MHz" or "Speed: 4800 MT/s"
parts = line.split(":")[1].strip().split()
freq = int(parts[0])
if freq > 0 and (frequency_mhz is None or freq > frequency_mhz):
frequency_mhz = freq
except (ValueError, IndexError):
pass
elif line.startswith("Type:") and mem_type is None:
type_val = line.split(":")[1].strip()
if type_val and type_val != "Unknown":
mem_type = type_val
elif line.startswith("Size:") and "MB" in line or "GB" in line:
dimm_count += 1
return frequency_mhz, mem_type, dimm_count if dimm_count > 0 else None
def _detect_memory_frequency_sysfs() -> Optional[int]:
"""Try to detect memory frequency from sysfs."""
# This is a fallback and may not work on all systems
try:
# Try reading from edac
edac_path = Path("/sys/devices/system/edac/mc")
if edac_path.exists():
for mc_dir in edac_path.iterdir():
freq_file = mc_dir / "mc_config"
if freq_file.exists():
content = freq_file.read_text()
# Parse for frequency information
# Format varies by system
pass
except (OSError, IOError):
pass
return None
def _parse_macos_memory(output: str) -> tuple[Optional[int], Optional[str]]:
"""Parse macOS system_profiler memory output."""
frequency_mhz: Optional[int] = None
mem_type: Optional[str] = None
for line in output.split("\n"):
line = line.strip()
if "Speed:" in line:
try:
parts = line.split(":")[1].strip().split()
frequency_mhz = int(parts[0])
except (ValueError, IndexError):
pass
elif "Type:" in line:
mem_type = line.split(":")[1].strip()
return frequency_mhz, mem_type
def detect_ram_gb() -> float:
    """Return total system RAM in GB (one decimal), or 0.0 if undetectable.

    Sources, in order: /proc/meminfo (Linux), `sysctl hw.memsize` (macOS),
    then psutil if it happens to be installed.
    """
    system = platform.system()
    if system == "Linux":
        try:
            with open("/proc/meminfo", "r") as f:
                for entry in f:
                    # "MemTotal:       32780516 kB"
                    if entry.startswith("MemTotal:"):
                        return round(int(entry.split()[1]) / 1024 / 1024, 1)
        except (OSError, IOError, ValueError):
            pass
    elif system == "Darwin":
        raw = run_command(["sysctl", "-n", "hw.memsize"])
        if raw:
            return round(int(raw) / 1024 / 1024 / 1024, 1)
    # Last resort: psutil, if available.
    try:
        import psutil

        return round(psutil.virtual_memory().total / 1024 / 1024 / 1024, 1)
    except ImportError:
        return 0.0
def detect_available_ram_gb() -> float:
    """Return currently available system RAM in GB (one decimal), or 0.0.

    Reads MemAvailable from /proc/meminfo on Linux, otherwise falls back to
    psutil when installed.
    """
    if platform.system() == "Linux":
        try:
            with open("/proc/meminfo", "r") as f:
                for entry in f:
                    if entry.startswith("MemAvailable:"):
                        return round(int(entry.split()[1]) / 1024 / 1024, 1)
        except (OSError, IOError, ValueError):
            pass
    # Fallback: psutil, if available.
    try:
        import psutil

        return round(psutil.virtual_memory().available / 1024 / 1024 / 1024, 1)
    except ImportError:
        return 0.0
def detect_disk_space_gb(path: str = "/") -> tuple[float, float]:
    """Return (available_gb, total_gb) for *path*, rounded to one decimal.

    Returns (0.0, 0.0) when the path cannot be inspected (e.g. nonexistent).
    """
    gb = 1024**3
    try:
        # Uses the module-level shutil import; the old function-local
        # `import shutil` needlessly shadowed it.
        usage = shutil.disk_usage(path)
    except (OSError, IOError):
        return 0.0, 0.0
    return round(usage.free / gb, 1), round(usage.total / gb, 1)
def get_installed_package_version(package_name: str) -> Optional[str]:
    """Return the installed version of *package_name*, or None on any failure.

    The broad except deliberately covers both PackageNotFoundError and an
    unavailable importlib.metadata module.
    """
    try:
        from importlib.metadata import version as pkg_version

        return pkg_version(package_name)
    except Exception:
        return None
def get_system_info() -> SystemInfo:
    """Gather complete system information.

    Aggregates the individual detectors (Python/platform, CUDA, GPUs, CPU,
    RAM, environment managers) into one SystemInfo snapshot. May invoke
    external tools (nvidia-smi, nvcc, conda, ...), so it can take a moment.
    """
    return SystemInfo(
        python_version=platform.python_version(),
        platform=f"{platform.system()} {platform.release()}",
        cuda_version=detect_cuda_version(),
        gpus=detect_gpus(),
        cpu=detect_cpu_info(),
        ram_gb=detect_ram_gb(),
        env_managers=detect_env_managers(),
    )
def is_in_virtual_env() -> bool:
    """Report whether the interpreter runs inside a virtualenv/venv/conda env."""
    # Old-style virtualenv sets sys.real_prefix.
    if hasattr(sys, "real_prefix"):
        return True
    # venv makes base_prefix differ from prefix.
    if hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix:
        return True
    # Activation scripts export these variables.
    return os.environ.get("VIRTUAL_ENV") is not None or os.environ.get("CONDA_PREFIX") is not None
def get_current_env_name() -> Optional[str]:
    """Return the active environment's name (conda first, then venv), or None."""
    conda_env = os.environ.get("CONDA_DEFAULT_ENV")
    if conda_env:
        return conda_env
    venv_dir = os.environ.get("VIRTUAL_ENV")
    if venv_dir:
        # The environment name is the directory's basename.
        return Path(venv_dir).name
    return None
# Import sys for is_in_virtual_env
import sys # noqa: E402
@dataclass
class StorageLocation:
    """Information about a storage location."""

    # Candidate directory for storing models (may not exist yet).
    path: str
    # Free space in GB measured at this path or its nearest existing parent.
    available_gb: float
    # Total filesystem size in GB.
    total_gb: float
    # Whether the path (or its nearest existing parent) is writable.
    is_writable: bool
    # Mount point / existing parent the space figures were measured from.
    mount_point: str
def scan_storage_locations(min_size_gb: float = 50.0) -> list[StorageLocation]:
    """
    Scan system for potential model storage locations.

    Looks for:
    - Large mounted filesystems (> min_size_gb)
    - Common model storage paths
    - User home directory

    Args:
        min_size_gb: Minimum available space in GB to consider

    Returns:
        List of StorageLocation sorted by available space (descending);
        only writable locations are included.
    """
    locations: dict[str, StorageLocation] = {}  # Use dict to deduplicate by path
    # Get all mount points from /proc/mounts (Linux)
    mount_points = _get_mount_points()
    for mount_point in mount_points:
        try:
            available_gb, total_gb = detect_disk_space_gb(mount_point)
            # Skip small or pseudo filesystems
            if total_gb < 10:
                continue
            # Check if writable
            is_writable = os.access(mount_point, os.W_OK)
            # Create potential model paths under this mount
            potential_paths = _get_potential_model_paths(mount_point)
            for path in potential_paths:
                if path in locations:
                    continue
                # Get actual available space for this path
                path_available, path_total = detect_disk_space_gb(path)
                if path_available >= min_size_gb:
                    # Not-yet-existing candidates inherit the mount's writability.
                    path_writable = os.access(path, os.W_OK) if os.path.exists(path) else is_writable
                    locations[path] = StorageLocation(
                        path=path,
                        available_gb=path_available,
                        total_gb=path_total,
                        is_writable=path_writable,
                        mount_point=mount_point,
                    )
        except (OSError, IOError):
            continue
    # Also check common model storage locations
    common_paths = [
        str(Path.home() / ".ktransformers" / "models"),
        str(Path.home() / "models"),
        str(Path.home() / ".cache" / "huggingface"),
        "/data/models",
        "/models",
        "/opt/models",
    ]
    for path in common_paths:
        if path in locations:
            continue
        try:
            # Walk up to the nearest existing ancestor for paths that don't
            # exist yet, and measure space/writability there.
            check_path = path
            while not os.path.exists(check_path) and check_path != "/":
                check_path = str(Path(check_path).parent)
            if os.path.exists(check_path):
                available_gb, total_gb = detect_disk_space_gb(check_path)
                if available_gb >= min_size_gb:
                    is_writable = os.access(check_path, os.W_OK)
                    locations[path] = StorageLocation(
                        path=path,
                        available_gb=available_gb,
                        total_gb=total_gb,
                        is_writable=is_writable,
                        mount_point=check_path,
                    )
        except (OSError, IOError):
            continue
    # Sort by available space descending, then by path
    sorted_locations = sorted(locations.values(), key=lambda x: (-x.available_gb, x.path))
    # Filter to only writable locations
    return [loc for loc in sorted_locations if loc.is_writable]
def _get_mount_points() -> list[str]:
"""Get all mount points on the system."""
mount_points = []
if platform.system() == "Linux":
try:
with open("/proc/mounts", "r") as f:
for line in f:
parts = line.split()
if len(parts) >= 2:
mount_point = parts[1]
fs_type = parts[2] if len(parts) > 2 else ""
# Skip pseudo filesystems
skip_fs = {
"proc",
"sysfs",
"devpts",
"tmpfs",
"cgroup",
"cgroup2",
"pstore",
"securityfs",
"debugfs",
"hugetlbfs",
"mqueue",
"fusectl",
"configfs",
"devtmpfs",
"efivarfs",
"autofs",
"binfmt_misc",
"overlay",
"nsfs",
"tracefs",
}
if fs_type in skip_fs:
continue
# Skip paths that are clearly system paths
skip_prefixes = ("/sys", "/proc", "/dev", "/run/user")
if any(mount_point.startswith(p) for p in skip_prefixes):
continue
mount_points.append(mount_point)
except (OSError, IOError):
pass
# Always include home and root
mount_points.extend([str(Path.home()), "/"])
# Deduplicate while preserving order
seen = set()
unique_mounts = []
for mp in mount_points:
if mp not in seen:
seen.add(mp)
unique_mounts.append(mp)
return unique_mounts
def _get_potential_model_paths(mount_point: str) -> list[str]:
"""Get potential model storage paths under a mount point."""
paths = []
# The mount point itself (for dedicated data drives)
if mount_point not in ("/", "/home"):
paths.append(mount_point)
paths.append(os.path.join(mount_point, "models"))
# If it's under home, suggest standard locations
home = str(Path.home())
if mount_point == home or mount_point == "/home":
paths.append(os.path.join(home, ".ktransformers", "models"))
paths.append(os.path.join(home, "models"))
# For root mount, suggest /data or /opt
if mount_point == "/":
paths.extend(["/data/models", "/opt/models"])
# Check for common data directories on this mount
for subdir in ["data", "models", "ai", "llm", "huggingface"]:
potential = os.path.join(mount_point, subdir)
if os.path.exists(potential) and os.path.isdir(potential):
paths.append(potential)
return paths
def format_size_gb(size_gb: float) -> str:
    """Render a GB value as e.g. '750.0GB', switching to TB at 1000 GB."""
    # Decimal (1000-based) TB conversion, matching the GB figures elsewhere.
    return f"{size_gb / 1000:.1f}TB" if size_gb >= 1000 else f"{size_gb:.1f}GB"
@dataclass
class LocalModel:
    """Information about a locally detected model."""

    # Display name: the model directory's basename.
    name: str
    # Path of the model directory.
    path: str
    # Total directory size in GB, rounded to one decimal.
    size_gb: float
    model_type: str  # "huggingface", "gguf", "safetensors"
    # True when the directory contains a config.json.
    has_config: bool
    # Number of weight files (.safetensors/.gguf; falls back to .bin count).
    file_count: int
def scan_local_models(search_paths: list[str], max_depth: int = 3) -> list[LocalModel]:
    """
    Scan directories for locally downloaded models.

    Recognized layouts: HuggingFace directories (config.json), bare
    .safetensors collections, and .gguf files.

    Args:
        search_paths: Directories to search; nonexistent ones are skipped.
        max_depth: Maximum directory depth to recurse per search path.

    Returns:
        Detected models, largest first.
    """
    found: dict[str, LocalModel] = {}  # Keyed by path to deduplicate.
    for root in search_paths:
        if os.path.exists(root):
            _scan_directory_for_models(root, found, current_depth=0, max_depth=max_depth)
    return sorted(found.values(), key=lambda m: -m.size_gb)
def _scan_directory_for_models(
    directory: str, models: dict[str, LocalModel], current_depth: int, max_depth: int
) -> None:
    """Depth-limited recursive model search; results are added to *models*."""
    if current_depth > max_depth:
        return
    try:
        children = list(os.scandir(directory))
    except (PermissionError, OSError):
        return
    found = _detect_model_in_directory(directory, children)
    if found is not None:
        # Record the model and stop descending: its subdirectories belong to it.
        models[found.path] = found
        return
    # Recurse into visible subdirectories only.
    for child in children:
        if child.is_dir() and not child.name.startswith("."):
            _scan_directory_for_models(child.path, models, current_depth + 1, max_depth)
def _detect_model_in_directory(directory: str, entries: list) -> Optional[LocalModel]:
    """Classify *directory* as a model directory, or return None.

    Classification order: HF (config.json + safetensors) > gguf > bare
    safetensors > config.json alongside known index/weight files. Directories
    under ~100 MB are treated as incomplete and rejected.
    """
    names = {e.name for e in entries}
    has_config = "config.json" in names
    safetensors = [e for e in entries if e.name.endswith(".safetensors") and e.is_file()]
    ggufs = [e for e in entries if e.name.endswith(".gguf") and e.is_file()]

    if has_config and safetensors:
        kind = "huggingface"
    elif ggufs:
        kind = "gguf"
    elif safetensors:
        kind = "safetensors"
    elif has_config and names & {
        # Config without loose weights: accept if known weight/index files exist
        # (possibly an incomplete or sharded checkout).
        "model.safetensors.index.json",
        "pytorch_model.bin.index.json",
        "model.safetensors",
        "pytorch_model.bin",
    }:
        kind = "huggingface"
    else:
        return None

    size_gb = _get_directory_size(directory) / (1024**3)
    # Below ~100 MB it is likely config-only or a partial download; skip.
    if size_gb < 0.1:
        return None

    weight_files = len(safetensors) + len(ggufs)
    if not weight_files:
        # Fall back to counting legacy .bin weight files.
        weight_files = len([e for e in entries if e.name.endswith(".bin") and e.is_file()])

    return LocalModel(
        name=os.path.basename(directory),
        path=directory,
        size_gb=round(size_gb, 1),
        model_type=kind,
        has_config=has_config,
        file_count=weight_files,
    )
def _get_directory_size(directory: str) -> int:
"""Get total size of a directory in bytes."""
total_size = 0
try:
for entry in os.scandir(directory):
try:
if entry.is_file(follow_symlinks=False):
total_size += entry.stat().st_size
elif entry.is_dir(follow_symlinks=False):
total_size += _get_directory_size(entry.path)
except (PermissionError, OSError):
continue
except (PermissionError, OSError):
pass
return total_size
def scan_models_in_location(location: StorageLocation, max_depth: int = 2) -> list[LocalModel]:
    """Scan a StorageLocation (and its common model subdirs) for local models."""
    roots = [location.path]
    # Conventional places models accumulate beneath a storage root.
    for leaf in ["models", "huggingface", "hub", ".cache/huggingface/hub"]:
        candidate = os.path.join(location.path, leaf)
        if os.path.exists(candidate):
            roots.append(candidate)
    return scan_local_models(roots, max_depth=max_depth)
@dataclass
class CPUBuildFeatures:
    """CPU features for build configuration."""

    # True only when amx_tile, amx_int8 AND amx_bf16 are all present.
    has_amx: bool
    # AVX512 foundation (avx512f flag).
    has_avx512: bool
    # avx512_vnni / avx512vnni (flag spelling varies by kernel).
    has_avx512_vnni: bool
    # avx512_bf16 / avx512bf16.
    has_avx512_bf16: bool
    has_avx2: bool
    recommended_instruct: str  # NATIVE, AVX512, AVX2
    recommended_amx: bool
def detect_cpu_build_features() -> CPUBuildFeatures:
    """
    Detect CPU features for build configuration.

    This is used to auto-configure kt-kernel source builds. On Linux the
    flags line of /proc/cpuinfo is parsed; on macOS `sysctl
    machdep.cpu.features` is queried.

    Returns:
        CPUBuildFeatures with detection results and a recommended target:
        NATIVE when any of AMX/AVX512/AVX2 is present, otherwise AVX2;
        AMX is recommended only when fully supported.
    """
    has_amx = False
    has_avx512 = False
    has_avx512_vnni = False
    has_avx512_bf16 = False
    has_avx2 = False
    system = platform.system()
    if system == "Linux":
        try:
            with open("/proc/cpuinfo", "r") as f:
                cpuinfo = f.read()
        except (OSError, IOError):
            cpuinfo = ""
        # Only the first "flags" line is inspected (identical across cores).
        for line in cpuinfo.split("\n"):
            if not line.startswith("flags"):
                continue
            flag_set = {token.lower() for token in line.split(":")[1].strip().split()}
            # AMX requires tile + int8 + bf16 together.
            has_amx = {"amx_tile", "amx_int8", "amx_bf16"} <= flag_set
            has_avx512 = "avx512f" in flag_set
            # Flag spelling varies between avx512_vnni and avx512vnni.
            has_avx512_vnni = bool({"avx512_vnni", "avx512vnni"} & flag_set)
            has_avx512_bf16 = bool({"avx512_bf16", "avx512bf16"} & flag_set)
            has_avx2 = "avx2" in flag_set
            break
    elif system == "Darwin":
        features = run_command(["sysctl", "-n", "machdep.cpu.features"])
        if features:
            # macOS doesn't have AMX or AVX512 typically.
            has_avx2 = "avx2" in {token.lower() for token in features.split()}
    # Any detected SIMD tier builds with the NATIVE target; otherwise force
    # the portable AVX2 target. AMX is only recommended when fully present.
    recommended_instruct = "NATIVE" if (has_amx or has_avx512 or has_avx2) else "AVX2"
    return CPUBuildFeatures(
        has_amx=has_amx,
        has_avx512=has_avx512,
        has_avx512_vnni=has_avx512_vnni,
        has_avx512_bf16=has_avx512_bf16,
        has_avx2=has_avx2,
        recommended_instruct=recommended_instruct,
        recommended_amx=has_amx,
    )