Files
ktransformers/kt-kernel/python/cli/utils/model_verifier.py
Oql 56cbd69ac4 kt-cli enhancement (#1834)
* [feat]: redesign kt run interactive configuration with i18n support

- Redesign kt run with 8-step interactive flow (model selection, inference method, NUMA/CPU, GPU experts, KV cache, GPU/TP selection, parsers, host/port)
- Add configuration save/load system (~/.ktransformers/run_configs.yaml)
- Add i18n support for kt chat (en/zh translations)
- Add universal input validators with auto-retry and Chinese comma support
- Add port availability checker with auto-suggestion
- Add parser configuration (--tool-call-parser, --reasoning-parser)
- Remove tuna command and clean up redundant files
- Fix: variable reference bug in run.py, filter to show only MoE models

* [feat]: unify model selection UI and enable shared experts fusion by default

- Unify kt run model selection table with kt model list display
  * Add Total size, MoE Size, Repo, and SHA256 status columns
  * Use consistent formatting and styling
  * Improve user decision-making with more information

- Enable --disable-shared-experts-fusion by default
  * Change default value from False to True
  * Users can still override with --enable-shared-experts-fusion

* [feat]: improve kt chat with performance metrics and better CJK support

- Add performance metrics display after each response
  * Total time, TTFT (Time To First Token), TPOT (Time Per Output Token)
  * Accurate input/output token counts using model tokenizer
  * Fallback to estimation if tokenizer unavailable
  * Metrics shown in dim style (not prominent)

- Fix Chinese character input issues
  * Replace Prompt.ask() with console.input() for better CJK support
  * Fixes backspace deletion showing half-characters

- Suppress NumPy subnormal warnings
  * Filter "The value of the smallest subnormal" warnings
  * Cleaner CLI output on certain hardware environments

* [fix]: correct TTFT measurement in kt chat

- Move start_time initialization before API call
- Previously start_time was set when receiving first chunk, causing TTFT ≈ 0ms
- Now correctly measures time from request sent to first token received

* [docs]: 添加 Clawdbot 集成指南 - KTransformers 企业级 AI 助手部署方案

* [docs]: 强调推荐使用 Kimi K2.5 作为核心模型,突出企业级推理能力

* [docs]: 添加 Clawdbot 飞书接入教程链接

* [feat]: improve CLI table display, model verification, and chat experience

- Add sequence number (#) column to all model tables by default
- Filter kt edit to show only MoE GPU models (exclude AMX)
- Extend kt model verify to check *.json and *.py files in addition to weights
- Fix re-verification bug where repaired files caused false failures
- Suppress tokenizer debug output in kt chat token counting

* [fix]: fix cpu cores.

---------

Co-authored-by: skqliao <skqliao@gmail.com>
2026-02-04 16:44:54 +08:00

919 lines
34 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Model Verifier
SHA256 verification for model integrity
"""
import hashlib
import requests
import os
from pathlib import Path
from typing import Dict, Any, Literal, Tuple
from concurrent.futures import ProcessPoolExecutor, as_completed
def _compute_file_sha256(file_path: Path) -> Tuple[str, str, float]:
    """
    Hash a single file with SHA256 (worker entry point for the process pool).

    Args:
        file_path: Path of the file to hash.

    Returns:
        Tuple of (filename, sha256_hex_digest, file_size_mb), where filename is
        the base name of ``file_path`` and the size is expressed in MiB.
    """
    read_size = 8192 * 1024  # 8MB per read so huge weight shards never sit fully in memory
    digest = hashlib.sha256()
    size_mb = file_path.stat().st_size / (1024 * 1024)

    with open(file_path, "rb") as stream:
        while True:
            piece = stream.read(read_size)
            if not piece:
                break
            digest.update(piece)

    return file_path.name, digest.hexdigest(), size_mb
def check_huggingface_connectivity(timeout: int = 5) -> Tuple[bool, str]:
    """
    Check if HuggingFace is accessible.

    Sends a single HEAD request (following redirects) to https://huggingface.co.

    Args:
        timeout: Connection timeout in seconds

    Returns:
        Tuple of (is_accessible, message). Any HTTP status below 500 counts as
        accessible; a 5xx status or a request failure does not.
    """
    test_url = "https://huggingface.co"

    # Keep only the network call inside the try so the handlers below cannot
    # mask an unrelated error.
    try:
        response = requests.head(test_url, timeout=timeout, allow_redirects=True)
    except requests.exceptions.Timeout:
        return False, f"Connection to {test_url} timed out"
    except requests.exceptions.ConnectionError:
        return False, f"Cannot connect to {test_url}"
    except requests.exceptions.RequestException as e:
        return False, f"Connection error: {str(e)}"

    if response.status_code < 500:  # 2xx, 3xx, 4xx are all considered "accessible"
        return True, "HuggingFace is accessible"

    # The host answered but is failing server-side: report the status instead
    # of a generic "unknown" error so the user can tell what happened.
    return False, f"{test_url} returned a server error (HTTP {response.status_code})"
def verify_model_integrity(
    repo_type: Literal["huggingface", "modelscope"],
    repo_id: str,
    local_dir: Path,
    progress_callback=None,
) -> Dict[str, Any]:
    """
    Verify local model integrity against the digests published by the remote repository.

    Verifies all important files found directly inside ``local_dir``:
    - *.safetensors (weights)
    - *.json (config files)
    - *.py (custom model code)

    Remote and local files are matched by base name. A remote digest of 64 hex
    characters is compared with the local SHA256. A 40-character digest is
    treated as a git blob id (what HuggingFace reports for files not stored in
    LFS) and compared with the local git blob SHA-1. Files for which the remote
    publishes no digest are skipped rather than failed.

    Args:
        repo_type: Type of repository ("huggingface" or "modelscope")
        repo_id: Repository ID (e.g., "deepseek-ai/DeepSeek-V3")
        local_dir: Local directory containing model files
        progress_callback: Optional callback function(message: str) for progress updates

    Returns:
        Dictionary with verification results:
        {
            "status": "passed" | "failed" | "error",
            "files_checked": int (files actually compared or found missing),
            "files_passed": int,
            "files_failed": [list of filenames, suffixed with the failure reason],
            "error_message": str (optional),
            "is_network_error": bool (only present when an exception was caught)
        }
    """

    def report_progress(msg: str):
        """Forward a progress line to the optional callback."""
        if progress_callback:
            progress_callback(msg)

    def error_result(message: str, is_network_error=None) -> Dict[str, Any]:
        """Build an "error" result; the network flag is only added by the exception handlers."""
        result = {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": message,
        }
        if is_network_error is not None:
            result["is_network_error"] = is_network_error
        return result

    def git_blob_sha1(path: Path):
        """Return the git blob id of a file (SHA-1 over "blob <size>" + NUL + content), or None if unreadable."""
        try:
            digest = hashlib.sha1(b"blob %d\x00" % path.stat().st_size)
            with open(path, "rb") as fh:
                for chunk in iter(lambda: fh.read(1024 * 1024), b""):
                    digest.update(chunk)
            return digest.hexdigest()
        except OSError:
            return None

    try:
        # Convert repo_type to platform format
        platform = "hf" if repo_type == "huggingface" else "ms"

        # 1. Fetch official hashes from remote
        report_progress("Fetching official SHA256 hashes from remote repository...")
        official_hashes = fetch_model_sha256(repo_id, platform)
        report_progress(f"✓ Fetched {len(official_hashes)} file hashes from remote")

        if not official_hashes:
            return error_result(f"No verifiable files found in remote repository: {repo_id}")

        # 2. Calculate local SHA256 hashes with progress
        report_progress("Calculating SHA256 for local files...")
        local_files = [
            f for pattern in ("*.safetensors", "*.json", "*.py") for f in local_dir.glob(pattern) if f.is_file()
        ]
        if not local_files:
            return error_result(f"No verifiable files found in local directory: {local_dir}")

        local_hashes = calculate_local_sha256(
            local_dir,
            file_pattern="*.safetensors",  # Unused when files_list is provided
            progress_callback=report_progress,
            files_list=local_files,
        )
        report_progress(f"✓ Calculated {len(local_hashes)} local file hashes")

        # 3. Compare hashes with progress
        total = len(official_hashes)
        report_progress(f"Comparing {total} files...")

        files_failed = []
        files_missing = []
        files_passed = 0
        files_skipped = 0

        for idx, (filename, official_hash) in enumerate(official_hashes.items(), 1):
            # Remote names may carry a path; local hashes are keyed by base name.
            file_basename = Path(filename).name
            local_hash = local_hashes.get(file_basename)
            expected = (official_hash or "").lower()

            if local_hash is None:
                files_missing.append(filename)
                report_progress(f" [{idx}/{len(official_hashes)}] ✗ {file_basename} - MISSING")
                continue

            if not expected:
                # Nothing to compare against: do not turn a missing remote digest into a failure.
                files_skipped += 1
                report_progress(f" [{idx}/{len(official_hashes)}] - {file_basename} - SKIPPED (no remote hash)")
                continue

            if len(expected) == 40:
                # 40 hex chars is a git blob id, not a SHA256: compare like with like.
                local_hash = git_blob_sha1(local_dir / file_basename)

            if (local_hash or "").lower() != expected:
                files_failed.append(f"{filename} (hash mismatch)")
                report_progress(f" [{idx}/{len(official_hashes)}] ✗ {file_basename} - HASH MISMATCH")
            else:
                files_passed += 1
                report_progress(f" [{idx}/{len(official_hashes)}] ✓ {file_basename}")

        # 4. Return results
        total_checked = total - files_skipped
        if files_failed or files_missing:
            all_failed = files_failed + [f"{f} (missing)" for f in files_missing]
            return {
                "status": "failed",
                "files_checked": total_checked,
                "files_passed": files_passed,
                "files_failed": all_failed,
                "error_message": f"{len(all_failed)} file(s) failed verification",
            }
        return {
            "status": "passed",
            "files_checked": total_checked,
            "files_passed": files_passed,
            "files_failed": [],
        }

    except ImportError as e:
        return error_result(
            f"Missing required package: {str(e)}. Install with: pip install huggingface-hub modelscope",
            is_network_error=False,
        )
    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
        requests.exceptions.RequestException,
    ) as e:
        # Network-related errors - suggest mirror
        error_msg = f"Network error: {str(e)}"
        if repo_type == "huggingface":
            error_msg += "\n\nTry using HuggingFace mirror:\n export HF_ENDPOINT=https://hf-mirror.com"
        return error_result(error_msg, is_network_error=True)
    except Exception as e:
        return error_result(f"Verification failed: {str(e)}", is_network_error=False)
def calculate_local_sha256(
    local_dir: Path, file_pattern: str = "*.safetensors", progress_callback=None, files_list: list[Path] = None
) -> Dict[str, str]:
    """
    Hash a set of files in parallel and return their SHA256 digests.

    Args:
        local_dir: Directory to scan; if it does not exist an empty dict is returned
        file_pattern: Glob pattern selecting files inside ``local_dir`` (ignored if files_list is provided)
        progress_callback: Optional callback function(message: str) for progress updates
        files_list: Optional pre-filtered list of files to hash (overrides file_pattern)

    Returns:
        Dictionary mapping file base name to SHA256 hex digest. Files whose
        hashing raised are reported through the callback and left out.
    """
    digests = {}

    if not local_dir.exists():
        return digests

    # Resolve the work list up front so progress can be reported as "i of N"
    if files_list is None:
        pending = [entry for entry in local_dir.glob(file_pattern) if entry.is_file()]
    else:
        pending = files_list

    total_files = len(pending)
    if total_files == 0:
        return digests

    # Cap the pool at 16 processes, and never spawn more than there are files
    worker_count = min(16, total_files)
    if progress_callback:
        progress_callback(f" Using {worker_count} parallel workers for SHA256 calculation")

    finished = 0
    with ProcessPoolExecutor(max_workers=worker_count) as pool:
        jobs = {}
        for path in pending:
            jobs[pool.submit(_compute_file_sha256, path)] = path

        # Consume results in completion order rather than submission order
        for job in as_completed(jobs):
            finished += 1
            try:
                name, digest, size_mb = job.result()
                digests[name] = digest
                if progress_callback:
                    progress_callback(f" [{finished}/{total_files}] ✓ {name} ({size_mb:.1f} MB)")
            except Exception as e:
                failed_path = jobs[job]
                if progress_callback:
                    progress_callback(f" [{finished}/{total_files}] ✗ {failed_path.name} - Error: {str(e)}")

    return digests
def fetch_model_sha256(
repo_id: str,
platform: Literal["hf", "ms"],
revision: str | None = None,
use_mirror: bool = False,
timeout: int | None = None,
) -> dict[str, str]:
"""
获取模型仓库中所有重要文件的 sha256 哈希值。
包括:
- *.safetensors (权重文件)
- *.json (配置文件config.json, tokenizer_config.json 等)
- *.py (自定义模型代码modeling.py, configuration.py 等)
Args:
repo_id: 仓库 ID例如 "Qwen/Qwen3-30B-A3B"
platform: 平台,"hf" (HuggingFace) 或 "ms" (ModelScope)
revision: 版本/分支,默认 HuggingFace 为 "main"ModelScope 为 "master"
use_mirror: 是否使用镜像(仅对 HuggingFace 有效)
timeout: 网络请求超时时间None 表示不设置超时
Returns:
dict: 文件名到 sha256 的映射,例如 {"model-00001-of-00016.safetensors": "abc123...", "config.json": "def456..."}
"""
if platform == "hf":
# 先尝试直连,失败后自动使用镜像
try:
if use_mirror:
return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=True, timeout=timeout)
else:
return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=False, timeout=timeout)
except Exception as e:
# 如果不是镜像模式且失败了,自动重试使用镜像
if not use_mirror:
return _fetch_from_huggingface(repo_id, revision or "main", use_mirror=True, timeout=timeout)
else:
raise e
elif platform == "ms":
return _fetch_from_modelscope(repo_id, revision or "master", timeout=timeout)
else:
raise ValueError(f"不支持的平台: {platform},请使用 'hf''ms'")
def _fetch_from_huggingface(
    repo_id: str, revision: str, use_mirror: bool = False, timeout: int | None = None
) -> dict[str, str]:
    """Fetch the digests of all important files from HuggingFace.

    The mirror is selected by passing ``endpoint`` to ``HfApi`` rather than by
    setting ``HF_ENDPOINT``: huggingface_hub reads that variable when it is
    first imported, so changing it afterwards does not redirect requests.

    Args:
        repo_id: Repository ID
        revision: Branch/revision
        use_mirror: Whether to use the mirror (hf-mirror.com); an ``HF_ENDPOINT``
            already set by the user takes precedence over the built-in mirror
        timeout: Network timeout in seconds; None means no timeout is set

    Returns:
        Mapping of repository file path to digest. LFS-tracked files map to
        their ``lfs.sha256``; other files map to ``blob_id`` (the git object id
        per the huggingface_hub docs, i.e. NOT a SHA256), or None if absent.
    """
    import os
    import socket

    # Import before touching any global state so a missing package cannot
    # leave the process-wide socket timeout modified.
    from huggingface_hub import HfApi

    endpoint = None  # None lets huggingface_hub use its configured default endpoint
    if use_mirror:
        endpoint = os.environ.get("HF_ENDPOINT") or "https://hf-mirror.com"

    # Set socket timeout if specified
    original_timeout = socket.getdefaulttimeout()
    if timeout is not None:
        socket.setdefaulttimeout(timeout)

    try:
        api = HfApi(endpoint=endpoint)
        all_files = api.list_repo_files(repo_id=repo_id, revision=revision)

        # Keep only the important files: *.safetensors, *.json, *.py
        important_files = [f for f in all_files if f.endswith((".safetensors", ".json", ".py"))]
        if not important_files:
            return {}

        paths_info = api.get_paths_info(
            repo_id=repo_id,
            paths=important_files,
            revision=revision,
        )

        result = {}
        for file_info in paths_info:
            lfs_info = getattr(file_info, "lfs", None)
            if lfs_info is not None:
                result[file_info.path] = lfs_info.sha256
            else:
                result[file_info.path] = getattr(file_info, "blob_id", None)
        return result
    finally:
        # Restore the original socket timeout
        socket.setdefaulttimeout(original_timeout)
def _fetch_from_modelscope(repo_id: str, revision: str, timeout: int | None = None) -> dict[str, str]:
    """Fetch the sha256 of all important files from ModelScope.

    Args:
        repo_id: Repository ID
        revision: Branch/revision
        timeout: Network timeout in seconds; None means no timeout is set

    Returns:
        Mapping of file name (the entry's "Name", else "Path") to its
        "Sha256"/"sha256" field, or None when the entry carries neither.
    """
    import socket
    from modelscope.hub.api import HubApi

    # Important files: *.safetensors, *.json, *.py
    wanted_suffixes = (".safetensors", ".json", ".py")

    # Apply the timeout process-wide for the duration of the request
    previous_timeout = socket.getdefaulttimeout()
    if timeout is not None:
        socket.setdefaulttimeout(timeout)

    try:
        entries = HubApi().get_model_files(model_id=repo_id, revision=revision)
        digests = {}
        for entry in entries:
            name = entry.get("Name", entry.get("Path", ""))
            if not name.endswith(wanted_suffixes):
                continue
            digests[name] = entry.get("Sha256", entry.get("sha256", None))
        return digests
    finally:
        # Restore the original socket timeout
        socket.setdefaulttimeout(previous_timeout)
def verify_model_integrity_with_progress(
    repo_type: Literal["huggingface", "modelscope"],
    repo_id: str,
    local_dir: Path,
    progress_callback=None,
    verbose: bool = False,
    use_mirror: bool = False,
    files_to_verify: list[str] | None = None,
    timeout: int | None = None,
) -> Dict[str, Any]:
    """
    Verify model integrity with enhanced progress reporting for Rich Progress bars.

    Performs the same check as verify_model_integrity() but reports progress
    with totals and positions suitable for a progress bar. Local *.safetensors,
    *.json and *.py files directly inside ``local_dir`` are hashed, matching the
    file types whose digests are fetched from the remote.

    Remote and local files are matched by base name. A 64-character remote
    digest is compared with the local SHA256. A 40-character digest is treated
    as a git blob id (what HuggingFace reports for files not stored in LFS) and
    compared with the local git blob SHA-1. Files with no remote digest are
    skipped rather than failed.

    The progress_callback is always invoked as callback(message, total, current);
    ``total`` and ``current`` are None for plain status updates.

    Args:
        repo_type: Repository type ("huggingface" or "modelscope")
        repo_id: Repository ID
        local_dir: Local directory path
        progress_callback: Optional callback for progress updates
        verbose: If True, output the remote and local digest for each file
        use_mirror: If True, use HuggingFace mirror (hf-mirror.com)
        files_to_verify: Optional list of specific files to verify (for re-verification);
            entries may carry the " (missing)" / " (hash mismatch)" markers of a previous result
        timeout: Network request timeout in seconds (None = no timeout)

    Returns:
        Dictionary with "status" ("passed" | "failed" | "error"), "files_checked",
        "files_passed", "files_failed", and optionally "error_message" and
        "is_network_error".
    """
    patterns = ("*.safetensors", "*.json", "*.py")

    def report_progress(msg: str, total=None, current=None):
        """Enhanced progress reporter"""
        if progress_callback:
            progress_callback(msg, total, current)

    def report_file(idx: int, total: int, line: str, remote, local):
        """Report one comparison outcome, appending both digests in verbose mode."""
        if verbose:
            line = f"{line}\n Remote: {remote}\n Local: {local}"
        report_progress(line, total=total, current=idx)

    def error_result(message: str, is_network_error=None) -> Dict[str, Any]:
        """Build an "error" result; the network flag is only added by the exception handlers."""
        result = {
            "status": "error",
            "files_checked": 0,
            "files_passed": 0,
            "files_failed": [],
            "error_message": message,
        }
        if is_network_error is not None:
            result["is_network_error"] = is_network_error
        return result

    def git_blob_sha1(path: Path):
        """Return the git blob id of a file (SHA-1 over "blob <size>" + NUL + content), or None if unreadable."""
        try:
            digest = hashlib.sha1(b"blob %d\x00" % path.stat().st_size)
            with open(path, "rb") as fh:
                for chunk in iter(lambda: fh.read(1024 * 1024), b""):
                    digest.update(chunk)
            return digest.hexdigest()
        except OSError:
            return None

    try:
        platform = "hf" if repo_type == "huggingface" else "ms"

        # 1. Fetch official hashes
        if files_to_verify:
            report_progress(f"Fetching SHA256 hashes for {len(files_to_verify)} files...")
        elif use_mirror and platform == "hf":
            report_progress("Fetching official SHA256 hashes from mirror (hf-mirror.com)...")
        else:
            report_progress("Fetching official SHA256 hashes from remote repository...")

        official_hashes = fetch_model_sha256(repo_id, platform, use_mirror=use_mirror, timeout=timeout)

        # Restrict to the requested files, if any. Markers such as "(missing)"
        # are stripped and only base names are compared, since remote keys may
        # carry a path.
        clean_filenames = None
        if files_to_verify:
            clean_filenames = {
                Path(f.replace(" (missing)", "").replace(" (hash mismatch)", "").strip()).name
                for f in files_to_verify
            }
            official_hashes = {k: v for k, v in official_hashes.items() if Path(k).name in clean_filenames}

        report_progress(f"✓ Fetched {len(official_hashes)} file hashes from remote")

        if not official_hashes:
            return error_result(f"No verifiable files found in remote repository: {repo_id}")

        # 2. Calculate local SHA256 hashes for every verifiable file type
        local_dir_path = Path(local_dir)
        files_to_hash = [
            f
            for pattern in patterns
            for f in local_dir_path.glob(pattern)
            if f.is_file() and (clean_filenames is None or f.name in clean_filenames)
        ]
        total_files = len(files_to_hash)

        if files_to_verify:
            report_progress(f"Calculating SHA256 for {total_files} repaired files...", total=total_files, current=0)
        else:
            report_progress("Calculating SHA256 for local files...", total=total_files, current=0)

        # Progress wrapper for hashing
        completed_count = [0]  # Use list for mutable closure

        def hash_progress_callback(msg: str):
            if "Using" in msg and "workers" in msg:
                report_progress(msg)
            elif "[" in msg and "/" in msg and "]" in msg:
                # Progress update like: [1/10] ✓ filename (123.4 MB)
                completed_count[0] += 1
                report_progress(msg, total=total_files, current=completed_count[0])

        local_hashes = calculate_local_sha256(
            local_dir_path,
            "*.safetensors",  # Unused: the explicit file list below takes precedence
            progress_callback=hash_progress_callback,
            files_list=files_to_hash,
        )
        report_progress(f"✓ Calculated {len(local_hashes)} local file hashes")

        # 3. Compare hashes
        total = len(official_hashes)
        report_progress(f"Comparing {len(official_hashes)} files...", total=total, current=0)

        files_failed = []
        files_missing = []
        files_passed = 0
        files_skipped = 0

        for idx, (filename, official_hash) in enumerate(official_hashes.items(), 1):
            file_basename = Path(filename).name
            local_hash = local_hashes.get(file_basename)
            expected = (official_hash or "").lower()

            if local_hash is None:
                files_missing.append(filename)
                report_file(
                    idx, total, f"[{idx}/{len(official_hashes)}] ✗ {file_basename} (missing)", official_hash, "<missing>"
                )
                continue

            if not expected:
                # Nothing to compare against: do not turn a missing remote digest into a failure.
                files_skipped += 1
                report_file(
                    idx,
                    total,
                    f"[{idx}/{len(official_hashes)}] - {file_basename} (skipped: no remote hash)",
                    official_hash,
                    local_hash,
                )
                continue

            if len(expected) == 40:
                # 40 hex chars is a git blob id, not a SHA256: compare like with like.
                local_hash = git_blob_sha1(local_dir_path / file_basename)

            if (local_hash or "").lower() != expected:
                files_failed.append(f"{filename} (hash mismatch)")
                report_file(
                    idx,
                    total,
                    f"[{idx}/{len(official_hashes)}] ✗ {file_basename} (hash mismatch)",
                    official_hash,
                    local_hash,
                )
            else:
                files_passed += 1
                report_file(idx, total, f"[{idx}/{len(official_hashes)}] ✓ {file_basename}", official_hash, local_hash)

        # 4. Return results
        total_checked = total - files_skipped
        if files_failed or files_missing:
            all_failed = files_failed + [f"{f} (missing)" for f in files_missing]
            return {
                "status": "failed",
                "files_checked": total_checked,
                "files_passed": files_passed,
                "files_failed": all_failed,
                "error_message": f"{len(all_failed)} file(s) failed verification",
            }
        return {
            "status": "passed",
            "files_checked": total_checked,
            "files_passed": files_passed,
            "files_failed": [],
        }

    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
        requests.exceptions.RequestException,
        TimeoutError,  # Socket timeout from socket.setdefaulttimeout()
        OSError,  # Network-related OS errors
    ) as e:
        error_msg = f"Network error: {str(e)}"
        if repo_type == "huggingface":
            error_msg += "\n\nTry using HuggingFace mirror:\n export HF_ENDPOINT=https://hf-mirror.com"
        return error_result(error_msg, is_network_error=True)
    except Exception as e:
        return error_result(f"Verification failed: {str(e)}", is_network_error=False)
def pre_operation_verification(user_model, user_registry, operation_name: str = "operation") -> None:
    """Interactively verify a model's integrity before an operation.

    Can be used before running or quantizing models to ensure integrity.

    Flow: return immediately if the model is already marked "passed"; otherwise
    optionally collect a repository ID, ask whether to verify, fetch the remote
    digests (HuggingFace, then its mirror, then ModelScope as fallbacks), hash
    the local *.safetensors, *.json and *.py files, compare, and store the
    outcome as ``sha256_status`` in the registry.

    A 64-character remote digest is compared with the local SHA256. A
    40-character digest is treated as a git blob id (what HuggingFace reports
    for files not stored in LFS) and compared with the local git blob SHA-1.
    Files with no remote digest are skipped.

    Args:
        user_model: UserModel object to verify (reads name, path, repo_type, repo_id, sha256_status)
        user_registry: UserModelRegistry instance, updated via update_model()
        operation_name: Name of the operation (e.g., "running", "quantizing")

    Raises:
        typer.Exit: If the user declines to continue after a failed or impossible
            verification, or chooses to repair the files.
    """
    from concurrent.futures import ThreadPoolExecutor

    import typer
    from rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
    from rich.prompt import Confirm, Prompt

    from kt_kernel.cli.utils.console import console, print_error, print_info, print_step, print_success, print_warning

    # Check if already verified
    if user_model.sha256_status == "passed":
        console.print()
        print_info("Model integrity already verified ✓")
        console.print()
        return

    # Model not verified yet
    console.print()
    console.print("[bold yellow]═══ Model Integrity Check ═══[/bold yellow]")
    console.print()

    # Check if repo_id exists
    if not user_model.repo_id:
        # No repo_id - ask user to provide one
        console.print("[yellow]No repository ID configured for this model.[/yellow]")
        console.print()
        console.print("To verify model integrity, we need the repository ID (e.g., 'deepseek-ai/DeepSeek-V3')")
        console.print()

        if not Confirm.ask("Would you like to configure repository ID now?", default=True):
            console.print()
            print_warning(f"Skipping verification. Model will be used for {operation_name} without integrity check.")
            console.print()
            return

        # Ask for repo type
        console.print()
        console.print("Repository type:")
        console.print(" [cyan][1][/cyan] HuggingFace")
        console.print(" [cyan][2][/cyan] ModelScope")
        console.print()
        repo_type_choice = Prompt.ask("Select repository type", choices=["1", "2"], default="1")
        repo_type = "huggingface" if repo_type_choice == "1" else "modelscope"

        # Ask for repo_id
        console.print()
        repo_id = Prompt.ask("Enter repository ID (e.g., deepseek-ai/DeepSeek-V3)")

        # Persist the choice and mirror it on the in-memory model
        user_registry.update_model(user_model.name, {"repo_type": repo_type, "repo_id": repo_id})
        user_model.repo_type = repo_type
        user_model.repo_id = repo_id

        console.print()
        print_success(f"Repository configured: {repo_type}:{repo_id}")
        console.print()

    # Now ask if user wants to verify
    console.print("[dim]Model integrity verification is a one-time check that ensures your[/dim]")
    console.print("[dim]model weights are not corrupted. This helps prevent runtime errors.[/dim]")
    console.print()

    if not Confirm.ask(f"Would you like to verify model integrity before {operation_name}?", default=True):
        console.print()
        print_warning(f"Skipping verification. Model will be used for {operation_name} without integrity check.")
        console.print()
        return

    # Perform verification
    console.print()
    print_step("Verifying model integrity...")
    console.print()

    # Check connectivity first
    use_mirror = False
    if user_model.repo_type == "huggingface":
        with console.status("[dim]Checking HuggingFace connectivity...[/dim]"):
            is_accessible, message = check_huggingface_connectivity(timeout=5)

        if not is_accessible:
            print_warning("HuggingFace Connection Failed")
            console.print()
            console.print(f" {message}")
            console.print()
            console.print(" [yellow]Auto-switching to HuggingFace mirror:[/yellow] [cyan]hf-mirror.com[/cyan]")
            console.print()
            use_mirror = True

    def fetch_with_timeout(repo_type, repo_id, use_mirror, timeout):
        """Fetch hashes in a worker thread; returns (hashes, failed) and never raises.

        Any failure, not only a timeout, is reported as failed=True.
        """
        executor = ThreadPoolExecutor(max_workers=1)
        try:
            platform = "hf" if repo_type == "huggingface" else "ms"
            future = executor.submit(fetch_model_sha256, repo_id, platform, use_mirror=use_mirror, timeout=timeout)
            return (future.result(timeout=timeout), False)
        except Exception:
            return (None, True)
        finally:
            # Do not wait: a hung request must not block the CLI.
            executor.shutdown(wait=False)

    def fetch_with_status(label, repo_type, use_mirror):
        """Run one fetch attempt under a console spinner that is always stopped."""
        with console.status(label):
            return fetch_with_timeout(repo_type, user_model.repo_id, use_mirror, 10)

    def git_blob_sha1(path):
        """Return the git blob id of a file (SHA-1 over "blob <size>" + NUL + content), or None if unreadable."""
        try:
            digest = hashlib.sha1(b"blob %d\x00" % path.stat().st_size)
            with open(path, "rb") as fh:
                for chunk in iter(lambda: fh.read(1024 * 1024), b""):
                    digest.update(chunk)
            return digest.hexdigest()
        except OSError:
            return None

    # Try fetching hashes
    official_hashes, timed_out = fetch_with_status(
        "[dim]Fetching remote hashes...[/dim]", user_model.repo_type, use_mirror
    )

    # Handle timeout with fallback
    if timed_out and user_model.repo_type == "huggingface" and not use_mirror:
        print_warning("HuggingFace Fetch Timeout (10s)")
        console.print()
        console.print(" [yellow]Trying HuggingFace mirror...[/yellow]")
        console.print()
        official_hashes, timed_out = fetch_with_status(
            "[dim]Fetching remote hashes from mirror...[/dim]", user_model.repo_type, True
        )

    if timed_out and user_model.repo_type == "huggingface":
        print_warning("HuggingFace Mirror Timeout (10s)")
        console.print()
        console.print(" [yellow]Fallback to ModelScope...[/yellow]")
        console.print()
        official_hashes, timed_out = fetch_with_status(
            "[dim]Fetching remote hashes from ModelScope...[/dim]", "modelscope", False
        )

    if not official_hashes or timed_out:
        print_error("Failed to fetch remote hashes (network timeout)")
        console.print()
        console.print(" [yellow]Unable to verify model integrity due to network issues.[/yellow]")
        console.print()
        if not Confirm.ask(f"Continue {operation_name} without verification?", default=False):
            raise typer.Exit(0)
        console.print()
        return

    console.print(f" [green]✓ Fetched {len(official_hashes)} file hashes[/green]")
    console.print()

    # Hash every file type the remote publishes digests for; hashing only the
    # weights would report every config/code file as missing.
    local_dir = Path(user_model.path)
    files_to_hash = [
        f for pattern in ("*.safetensors", "*.json", "*.py") for f in local_dir.glob(pattern) if f.is_file()
    ]

    files_failed = []
    files_missing = []
    files_passed = 0
    files_skipped = 0

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        console=console,
    ) as progress:
        # Calculate local hashes
        task = progress.add_task("[yellow]Calculating local SHA256...", total=len(files_to_hash))

        def hash_callback(msg):
            # Per-file lines look like "[i/N] ..."; the worker-count line has no brackets.
            if "[" in msg and "/" in msg and "]" in msg:
                progress.advance(task)

        local_hashes = calculate_local_sha256(
            local_dir, "*.safetensors", progress_callback=hash_callback, files_list=files_to_hash
        )
        progress.remove_task(task)

        console.print(f" [green]✓ Calculated {len(local_hashes)} local hashes[/green]")
        console.print()

        # Compare hashes
        task = progress.add_task("[blue]Comparing hashes...", total=len(official_hashes))

        for filename, official_hash in official_hashes.items():
            # Remote names may carry a path; local hashes are keyed by base name.
            file_basename = Path(filename).name
            local_hash = local_hashes.get(file_basename)
            expected = (official_hash or "").lower()

            if local_hash is None:
                files_missing.append(filename)
            elif not expected:
                # Nothing to compare against: do not turn a missing remote digest into a failure.
                files_skipped += 1
            else:
                if len(expected) == 40:
                    # 40 hex chars is a git blob id, not a SHA256: compare like with like.
                    local_hash = git_blob_sha1(local_dir / file_basename)
                if (local_hash or "").lower() == expected:
                    files_passed += 1
                else:
                    files_failed.append(f"{filename} (hash mismatch)")

            progress.advance(task)

        progress.remove_task(task)

    console.print()

    # Check results
    if not files_failed and not files_missing:
        # Verification passed
        user_registry.update_model(user_model.name, {"sha256_status": "passed"})
        print_success("Model integrity verification PASSED ✓")
        console.print()
        console.print(f" All {files_passed} files verified successfully")
        if files_skipped:
            console.print(f" [dim]{files_skipped} file(s) skipped (no remote hash available)[/dim]")
        console.print()
        return

    # Verification failed
    user_registry.update_model(user_model.name, {"sha256_status": "failed"})
    print_error("Model integrity verification FAILED")
    console.print()
    console.print(f" ✓ Passed: [green]{files_passed}[/green]")
    console.print(f" ✗ Failed: [red]{len(files_failed) + len(files_missing)}[/red]")
    console.print()

    if files_missing:
        console.print(f" [red]Missing files ({len(files_missing)}):[/red]")
        for f in files_missing[:5]:
            console.print(f" - {Path(f).name}")
        if len(files_missing) > 5:
            console.print(f" ... and {len(files_missing) - 5} more")
        console.print()

    if files_failed:
        console.print(f" [red]Hash mismatch ({len(files_failed)}):[/red]")
        for f in files_failed[:5]:
            console.print(f" - {f}")
        if len(files_failed) > 5:
            console.print(f" ... and {len(files_failed) - 5} more")
        console.print()

    console.print("[bold red]⚠ WARNING: Model weights may be corrupted![/bold red]")
    console.print()
    console.print("This could cause runtime errors or incorrect inference results.")
    console.print()

    # Ask if user wants to repair
    if Confirm.ask("Would you like to repair (re-download) the corrupted files?", default=True):
        console.print()
        print_info("Please run: [cyan]kt model verify " + user_model.name + "[/cyan]")
        console.print()
        console.print("The verify command will guide you through the repair process.")
        raise typer.Exit(0)

    # Ask if user wants to continue anyway
    console.print()
    if not Confirm.ask(
        f"[yellow]Continue {operation_name} with potentially corrupted weights?[/yellow]", default=False
    ):
        raise typer.Exit(0)

    console.print()
    print_warning(f"Proceeding with {operation_name} using unverified weights at your own risk...")
    console.print()