mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-18 22:20:03 +00:00
This change solves the basename collision bug by using UNIQUE(file_path) on the unified asset_references table. Key changes: Database: - Migration 0005 merges asset_cache_states and asset_infos into asset_references - AssetReference now contains: cache state fields (file_path, mtime_ns, needs_verify, is_missing, enrichment_level) plus info fields (name, owner_id, preview_id, etc.) - AssetReferenceMeta replaces AssetInfoMeta - AssetReferenceTag replaces AssetInfoTag - UNIQUE constraint on file_path prevents duplicate entries for same file Code: - New unified query module: asset_reference.py (replaces asset_info.py, cache_state.py) - Updated scanner, seeder, and services to use AssetReference - Updated API routes to use reference_id instead of asset_info_id Tests: - All 175 unit tests updated and passing - Integration tests require server environment (not run here) Amp-Thread-ID: https://ampcode.com/threads/T-019c4fe8-9dcb-75ce-bea8-ea786343a581 Co-authored-by: Amp <amp@ampcode.com>
351 lines
12 KiB
Python
351 lines
12 KiB
Python
"""Metadata extraction for asset scanning.
|
|
|
|
Tier 1: Filesystem metadata (zero parsing)
|
|
Tier 2: Safetensors header metadata (fast JSON read only)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import mimetypes
|
|
import os
|
|
import struct
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
# File extensions that are handled as safetensors containers.
SAFETENSORS_EXTENSIONS = frozenset((".safetensors", ".sft"))

# Largest JSON header (in bytes) we are willing to read from a safetensors
# file: 8 MiB.
MAX_SAFETENSORS_HEADER_SIZE = 8 << 20
|
|
|
|
def _register_custom_mime_types():
    """Make sure model/config file extensions resolve to our MIME types.

    This is re-run before each lookup because, per the project note,
    ``mimetypes.init()`` in server.py resets the type database and drops
    earlier registrations. The ``.safetensors`` mapping serves as a sentinel:
    if it already resolves, the whole set is assumed present and nothing is
    re-registered.
    """
    sentinel_type = mimetypes.guess_type("test.safetensors")[0]
    if sentinel_type == "application/safetensors":
        return

    # (MIME type, extension) pairs, registered in this order.
    custom_types = (
        ("application/safetensors", ".safetensors"),
        ("application/safetensors", ".sft"),
        ("application/pytorch", ".pt"),
        ("application/pytorch", ".pth"),
        ("application/pickle", ".ckpt"),
        ("application/pickle", ".pkl"),
        ("application/gguf", ".gguf"),
        ("application/yaml", ".yaml"),
        ("application/yaml", ".yml"),
    )
    for mime, extension in custom_types:
        mimetypes.add_type(mime, extension)


# Register once at import time so lookups work before the first explicit call.
_register_custom_mime_types()
|
|
|
|
|
|
@dataclass
class ExtractedMetadata:
    """Metadata learned about a single file during an asset scan.

    Tier 1 fields describe the file on disk (name, path, size, type). Tier 2,
    provenance and HuggingFace fields come from an embedded safetensors
    header when one was parsed and contained the corresponding keys.
    """

    # Tier 1: Filesystem (always available)
    filename: str = ""
    file_path: str = ""  # Full absolute path to the file
    content_length: int = 0
    content_type: str | None = None
    format: str = ""  # file extension without dot

    # Tier 2: Safetensors header (if available)
    base_model: str | None = None
    trained_words: list[str] | None = None
    air: str | None = None  # CivitAI AIR identifier
    has_preview_images: bool = False

    # Source provenance (populated if embedded in safetensors)
    source_url: str | None = None
    source_arn: str | None = None
    repo_url: str | None = None
    preview_url: str | None = None
    source_hash: str | None = None

    # HuggingFace specific
    repo_id: str | None = None
    revision: str | None = None
    filepath: str | None = None
    resolve_url: str | None = None

    def to_user_metadata(self) -> dict[str, Any]:
        """Build the dict stored in the AssetReference.user_metadata JSON field.

        ``filename``, ``content_length`` and ``format`` are always present;
        every other field is included only when its value is truthy.
        """
        result: dict[str, Any] = {
            "filename": self.filename,
            "content_length": self.content_length,
            "format": self.format,
        }

        # Optional fields emitted before the preview flag, in output order.
        for key in ("file_path", "content_type", "base_model", "trained_words", "air"):
            value = getattr(self, key)
            if value:
                result[key] = value

        # The preview flag is always stored as a literal True.
        if self.has_preview_images:
            result["has_preview_images"] = True

        # Provenance fields followed by HuggingFace fields, in output order.
        trailing_keys = (
            "source_url",
            "source_arn",
            "repo_url",
            "preview_url",
            "source_hash",
            "repo_id",
            "revision",
            "filepath",
            "resolve_url",
        )
        for key in trailing_keys:
            value = getattr(self, key)
            if value:
                result[key] = value

        return result

    def to_meta_rows(self, reference_id: str) -> list[dict]:
        """Flatten this metadata into asset_reference_meta rows.

        Each row fills exactly one of ``val_str`` / ``val_num`` / ``val_bool``;
        ``val_json`` is always None. String values are cut to 2048 characters
        and skipped when empty. ``trained_words`` becomes one row per word (the
        first 100 words), using the word's list index as ``ordinal``.
        """
        rows: list[dict] = []

        def emit(key, ordinal=0, *, text=None, number=None, flag=None):
            rows.append({
                "asset_reference_id": reference_id,
                "key": key,
                "ordinal": ordinal,
                "val_str": text,
                "val_num": number,
                "val_bool": flag,
                "val_json": None,
            })

        def emit_text(key, value, ordinal=0):
            # None and "" produce no row at all.
            if value:
                emit(key, ordinal, text=value[:2048])

        # Tier 1
        emit_text("filename", self.filename)
        if self.content_length is not None:
            emit("content_length", number=self.content_length)
        emit_text("content_type", self.content_type)
        emit_text("format", self.format)

        # Tier 2
        emit_text("base_model", self.base_model)
        emit_text("air", self.air)
        if self.has_preview_images:
            emit("has_preview_images", flag=self.has_preview_images)
        for position, word in enumerate((self.trained_words or [])[:100]):
            emit_text("trained_words", word, position)

        # Provenance, then HuggingFace.
        for key in (
            "source_url",
            "source_arn",
            "repo_url",
            "preview_url",
            "source_hash",
            "repo_id",
            "revision",
            "filepath",
            "resolve_url",
        ):
            emit_text(key, getattr(self, key))

        return rows
|
|
|
|
|
|
def _read_safetensors_header(
    path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE
) -> dict[str, Any] | None:
    """Read only the JSON header from a safetensors file.

    The file starts with an unsigned 64-bit little-endian length, followed by
    a UTF-8 JSON header of that many bytes. Only those bytes are read; no
    tensor data is loaded.

    Args:
        path: Absolute path to safetensors file
        max_size: Maximum header size in bytes to read (default 8MB); files
            declaring a larger header are rejected.

    Returns:
        The parsed header dict, or None when the file cannot be read, is
        truncated, declares an oversized header, is not valid UTF-8 JSON, or
        its top-level JSON value is not an object.
    """
    try:
        with open(path, "rb") as f:
            prefix = f.read(8)
            if len(prefix) < 8:
                return None
            (header_len,) = struct.unpack("<Q", prefix)
            if header_len > max_size:
                return None
            header_bytes = f.read(header_len)
        if len(header_bytes) < header_len:
            return None
        header = json.loads(header_bytes.decode("utf-8"))
    except (OSError, ValueError, RecursionError, struct.error):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # RecursionError is raised by json for pathologically nested input.
        return None
    # A valid header is a JSON object. Any other top-level value would break
    # callers that call .get() on the result.
    return header if isinstance(header, dict) else None
|
|
|
|
|
|
def _first_str_value(source: dict[str, Any], *keys: str) -> str | None:
    """Return the first non-empty string stored under any of ``keys``, else None."""
    for key in keys:
        value = source.get(key)
        if isinstance(value, str) and value:
            return value
    return None


def _tags_from_tag_frequency(raw: Any) -> list[str] | None:
    """Return up to 100 sorted unique tags from an ``ss_tag_frequency`` JSON string.

    The expected shape is ``{dataset: {tag: count, ...}, ...}``. Input that
    does not match that shape yields None instead of raising.
    """
    if not raw or not isinstance(raw, str):
        return None
    try:
        tag_freq = json.loads(raw)
    except (ValueError, RecursionError):
        return None
    if not isinstance(tag_freq, dict):
        return None
    all_tags: set[str] = set()
    for dataset_tags in tag_freq.values():
        if isinstance(dataset_tags, dict):
            all_tags.update(dataset_tags.keys())
    return sorted(all_tags)[:100] if all_tags else None


def _parse_trained_words(raw: Any) -> list[str] | None:
    """Normalize a direct ``trained_words`` metadata value to a list of strings.

    Accepts a JSON-encoded list, a comma-separated string, or an actual list.
    Non-string list entries are dropped. Returns None when nothing usable
    remains, including JSON scalars such as ``"123"`` or ``"null"``.
    """
    if isinstance(raw, str):
        try:
            parsed: Any = json.loads(raw)
        except (ValueError, RecursionError):
            parsed = raw  # not JSON: treat the raw text as comma-separated
        if isinstance(parsed, list):
            candidates = parsed
        elif isinstance(parsed, str):
            candidates = [w.strip() for w in parsed.split(",") if w.strip()]
        else:
            return None
    elif isinstance(raw, list):
        candidates = raw
    else:
        return None
    words = [w for w in candidates if isinstance(w, str)]
    return words or None


def _extract_safetensors_metadata(
    header: dict[str, Any], meta: ExtractedMetadata
) -> None:
    """Extract metadata from safetensors header __metadata__ section.

    Modifies meta in-place. Values of unexpected types (non-string fields,
    malformed ``ss_tag_frequency`` / ``trained_words``) are ignored rather
    than raising, so one bad key does not abort extraction of the others.
    """
    st_meta = header.get("__metadata__", {}) if isinstance(header, dict) else None
    if not isinstance(st_meta, dict):
        return

    # Common model metadata
    meta.base_model = _first_str_value(
        st_meta, "ss_base_model_version", "modelspec.base_model", "base_model"
    )

    # Trained words / trigger words, preferring the kohya tag-frequency table.
    tags = _tags_from_tag_frequency(st_meta.get("ss_tag_frequency"))
    if tags:
        meta.trained_words = tags

    # Direct trained_words field (some formats)
    if not meta.trained_words:
        words = _parse_trained_words(st_meta.get("trained_words"))
        if words:
            meta.trained_words = words

    # CivitAI AIR
    meta.air = _first_str_value(st_meta, "air", "modelspec.air")

    # Preview images (ssmd_cover_images)
    if st_meta.get("ssmd_cover_images"):
        meta.has_preview_images = True

    # Source provenance fields
    meta.source_url = _first_str_value(st_meta, "source_url")
    meta.source_arn = _first_str_value(st_meta, "source_arn")
    meta.repo_url = _first_str_value(st_meta, "repo_url")
    meta.preview_url = _first_str_value(st_meta, "preview_url")
    meta.source_hash = _first_str_value(st_meta, "source_hash", "sshs_model_hash")

    # HuggingFace fields
    meta.repo_id = _first_str_value(st_meta, "repo_id", "hf_repo_id")
    meta.revision = _first_str_value(st_meta, "revision", "hf_revision")
    meta.filepath = _first_str_value(st_meta, "filepath", "hf_filepath")
    meta.resolve_url = _first_str_value(st_meta, "resolve_url", "hf_url")
|
|
|
|
|
|
def extract_file_metadata(
    abs_path: str,
    stat_result: os.stat_result | None = None,
    enable_safetensors: bool = True,
    relative_filename: str | None = None,
) -> ExtractedMetadata:
    """Extract metadata from a file using tier 1 and optionally tier 2 methods.

    Tier 1 (always): Filesystem metadata from path and stat
    Tier 2 (optional): Safetensors header parsing if applicable

    Tier 2 is best-effort: any failure while reading or interpreting the
    header is logged at debug level and leaves the tier 1 fields intact.

    Args:
        abs_path: Absolute path to the file
        stat_result: Optional pre-fetched stat result (saves a syscall)
        enable_safetensors: Whether to parse safetensors headers (tier 2)
        relative_filename: Optional relative filename to use instead of basename
            (e.g., "flux/123/model.safetensors" for model paths)

    Returns:
        ExtractedMetadata with all available fields populated
    """
    meta = ExtractedMetadata()

    # Tier 1: Filesystem metadata
    meta.filename = relative_filename or os.path.basename(abs_path)
    meta.file_path = abs_path
    ext = os.path.splitext(abs_path)[1].lower()
    meta.format = ext.lstrip(".")  # "" when the path has no extension

    # MIME type guess. Re-register custom types first because mimetypes.init()
    # elsewhere may have reset the database.
    _register_custom_mime_types()
    meta.content_type, _ = mimetypes.guess_type(abs_path)

    # Size from stat; an unreadable file keeps content_length == 0.
    if stat_result is None:
        try:
            stat_result = os.stat(abs_path, follow_symlinks=True)
        except OSError:
            stat_result = None
    if stat_result is not None:
        meta.content_length = stat_result.st_size

    # Tier 2: Safetensors header (if applicable and enabled)
    if enable_safetensors and ext in SAFETENSORS_EXTENSIONS:
        try:
            header = _read_safetensors_header(abs_path)
            if header:
                _extract_safetensors_metadata(header, meta)
        except Exception as e:
            # Best-effort enrichment: never let a bad header break the scan.
            logging.debug("Safetensors meta extract failed %s: %s", abs_path, e)

    return meta
|