mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-19 06:30:07 +00:00
Amp-Thread-ID: https://ampcode.com/threads/T-019c2fe5-a3de-71cc-a6e5-67fe944a101e Co-authored-by: Amp <amp@ampcode.com>
308 lines
10 KiB
Python
308 lines
10 KiB
Python
"""Metadata extraction for asset scanning.
|
|
|
|
Tier 1: Filesystem metadata (zero parsing)
|
|
Tier 2: Safetensors header metadata (fast JSON read only)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import mimetypes
|
|
import os
|
|
import struct
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
# File extensions (lower-case, leading dot included) whose files are eligible
# for safetensors header parsing; compare against a lower-cased extension.
SAFETENSORS_EXTENSIONS = frozenset((".safetensors", ".sft"))
|
|
|
|
|
|
@dataclass
class ExtractedMetadata:
    """Metadata extracted from a file during scanning.

    Tier 1 fields describe the file itself (name, size, MIME type, extension).
    The remaining fields default to empty values and are filled in from
    safetensors header metadata when it is available.
    """

    # Tier 1: Filesystem (always available)
    filename: str = ""
    content_length: int = 0
    content_type: str | None = None
    format: str = ""  # file extension without dot

    # Tier 2: Safetensors header (if available)
    base_model: str | None = None
    trained_words: list[str] | None = None
    air: str | None = None  # CivitAI AIR identifier
    has_preview_images: bool = False

    # Source provenance (populated if embedded in safetensors)
    source_url: str | None = None
    source_arn: str | None = None
    repo_url: str | None = None
    preview_url: str | None = None
    source_hash: str | None = None

    # HuggingFace specific
    repo_id: str | None = None
    revision: str | None = None
    filepath: str | None = None
    resolve_url: str | None = None

    def to_user_metadata(self) -> dict[str, Any]:
        """Convert to user_metadata dict for AssetInfo.user_metadata JSON field.

        ``filename``, ``content_length`` and ``format`` are always present.
        Every other field is included only when it is truthy, so unset
        (None / empty / False) fields are omitted instead of stored as nulls.
        """
        data: dict[str, Any] = {
            "filename": self.filename,
            "content_length": self.content_length,
            "format": self.format,
        }
        # Optional keys, in the order they are emitted.
        optional_keys = (
            "content_type",
            # Tier 2
            "base_model",
            "trained_words",
            "air",
            "has_preview_images",
            # Source provenance
            "source_url",
            "source_arn",
            "repo_url",
            "preview_url",
            "source_hash",
            # HuggingFace
            "repo_id",
            "revision",
            "filepath",
            "resolve_url",
        )
        for key in optional_keys:
            value = getattr(self, key)
            if value:
                data[key] = value
        return data

    def to_meta_rows(self, asset_info_id: str) -> list[dict]:
        """Convert to asset_info_meta rows for typed/indexed querying.

        Every row has the keys ``asset_info_id``, ``key``, ``ordinal``,
        ``val_str``, ``val_num``, ``val_bool`` and ``val_json``. Exactly one
        ``val_*`` column is populated, and ``val_json`` is always None.

        - String fields produce a row only when the value is a non-empty
          ``str``; the stored value is truncated to 2048 characters.
        - ``content_length`` produces a numeric row unless it is None.
        - ``has_preview_images`` produces a ``val_bool=True`` row only when set.
        - ``trained_words`` produces one row per word for the first 100
          entries, with ``ordinal`` equal to the entry's index in the list.
        """
        max_str_len = 2048
        max_trained_words = 100
        rows: list[dict] = []

        def add_row(
            key: str,
            *,
            ordinal: int = 0,
            val_str: str | None = None,
            val_num: int | float | None = None,
            val_bool: bool | None = None,
        ) -> None:
            rows.append({
                "asset_info_id": asset_info_id,
                "key": key,
                "ordinal": ordinal,
                "val_str": val_str,
                "val_num": val_num,
                "val_bool": val_bool,
                "val_json": None,
            })

        def add_str(key: str, val: object, ordinal: int = 0) -> None:
            # Values can originate from untrusted file headers, so a non-string
            # (e.g. an int inside trained_words) is skipped rather than allowed
            # to raise on len()/slicing.
            if isinstance(val, str) and val:
                add_row(key, ordinal=ordinal, val_str=val[:max_str_len])

        # Tier 1
        add_str("filename", self.filename)
        if self.content_length is not None:
            add_row("content_length", val_num=self.content_length)
        add_str("content_type", self.content_type)
        add_str("format", self.format)

        # Tier 2
        add_str("base_model", self.base_model)
        add_str("air", self.air)
        if self.has_preview_images:
            add_row("has_preview_images", val_bool=True)

        # trained_words as multiple rows; ordinal is the index in the list
        if self.trained_words:
            for i, word in enumerate(self.trained_words[:max_trained_words]):
                add_str("trained_words", word, ordinal=i)

        # Source provenance, then HuggingFace fields
        for key in (
            "source_url",
            "source_arn",
            "repo_url",
            "preview_url",
            "source_hash",
            "repo_id",
            "revision",
            "filepath",
            "resolve_url",
        ):
            add_str(key, getattr(self, key))

        return rows
|
|
|
|
|
|
def _read_safetensors_header(path: str, max_size: int = 8 * 1024 * 1024) -> dict[str, Any] | None:
|
|
"""Read only the JSON header from a safetensors file.
|
|
|
|
This is very fast - reads 8 bytes for header length, then the JSON header.
|
|
No tensor data is loaded.
|
|
|
|
Args:
|
|
path: Absolute path to safetensors file
|
|
max_size: Maximum header size to read (default 8MB)
|
|
|
|
Returns:
|
|
Parsed header dict or None if failed
|
|
"""
|
|
try:
|
|
with open(path, "rb") as f:
|
|
header_bytes = f.read(8)
|
|
if len(header_bytes) < 8:
|
|
return None
|
|
length_of_header = struct.unpack("<Q", header_bytes)[0]
|
|
if length_of_header > max_size:
|
|
return None
|
|
header_data = f.read(length_of_header)
|
|
if len(header_data) < length_of_header:
|
|
return None
|
|
return json.loads(header_data.decode("utf-8"))
|
|
except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error):
|
|
return None
|
|
|
|
|
|
def _extract_safetensors_metadata(header: dict[str, Any], meta: ExtractedMetadata) -> None:
|
|
"""Extract metadata from safetensors header __metadata__ section.
|
|
|
|
Modifies meta in-place.
|
|
"""
|
|
st_meta = header.get("__metadata__", {})
|
|
if not isinstance(st_meta, dict):
|
|
return
|
|
|
|
# Common model metadata
|
|
meta.base_model = st_meta.get("ss_base_model_version") or st_meta.get("modelspec.base_model") or st_meta.get("base_model")
|
|
|
|
# Trained words / trigger words
|
|
trained_words = st_meta.get("ss_tag_frequency")
|
|
if trained_words and isinstance(trained_words, str):
|
|
try:
|
|
tag_freq = json.loads(trained_words)
|
|
# Extract unique tags from all datasets
|
|
all_tags: set[str] = set()
|
|
for dataset_tags in tag_freq.values():
|
|
if isinstance(dataset_tags, dict):
|
|
all_tags.update(dataset_tags.keys())
|
|
if all_tags:
|
|
meta.trained_words = sorted(all_tags)[:100]
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
# Direct trained_words field (some formats)
|
|
if not meta.trained_words:
|
|
tw = st_meta.get("trained_words")
|
|
if isinstance(tw, str):
|
|
try:
|
|
meta.trained_words = json.loads(tw)
|
|
except json.JSONDecodeError:
|
|
meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
|
|
elif isinstance(tw, list):
|
|
meta.trained_words = tw
|
|
|
|
# CivitAI AIR
|
|
meta.air = st_meta.get("air") or st_meta.get("modelspec.air")
|
|
|
|
# Preview images (ssmd_cover_images)
|
|
cover_images = st_meta.get("ssmd_cover_images")
|
|
if cover_images:
|
|
meta.has_preview_images = True
|
|
|
|
# Source provenance fields
|
|
meta.source_url = st_meta.get("source_url")
|
|
meta.source_arn = st_meta.get("source_arn")
|
|
meta.repo_url = st_meta.get("repo_url")
|
|
meta.preview_url = st_meta.get("preview_url")
|
|
meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash")
|
|
|
|
# HuggingFace fields
|
|
meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id")
|
|
meta.revision = st_meta.get("revision") or st_meta.get("hf_revision")
|
|
meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath")
|
|
meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url")
|
|
|
|
|
|
def extract_file_metadata(
    abs_path: str,
    stat_result: os.stat_result | None = None,
    enable_safetensors: bool = True,
    relative_filename: str | None = None,
) -> ExtractedMetadata:
    """Extract metadata from a file using tier 1 and optionally tier 2 methods.

    Tier 1 (always): Filesystem metadata from path and stat
    Tier 2 (optional): Safetensors header parsing if applicable

    Both tiers are best-effort. An ``OSError`` from ``os.stat`` leaves
    ``content_length`` at its default. Any exception raised while reading or
    interpreting a safetensors header is logged at debug level and not
    propagated.

    Args:
        abs_path: Absolute path to the file
        stat_result: Optional pre-fetched stat result (saves a syscall)
        enable_safetensors: Whether to parse safetensors headers (tier 2)
        relative_filename: Optional relative filename to use instead of basename
            (e.g., "flux/123/model.safetensors" for model paths)

    Returns:
        ExtractedMetadata with all available fields populated
    """
    meta = ExtractedMetadata()

    # Tier 1: Filesystem metadata
    # Use relative_filename if provided (for backward compatibility with existing behavior)
    meta.filename = relative_filename if relative_filename else os.path.basename(abs_path)
    _, ext = os.path.splitext(abs_path)
    ext = ext.lower()
    meta.format = ext.lstrip(".")

    # MIME type guess (derived from the path string only)
    mime_type, _ = mimetypes.guess_type(abs_path)
    meta.content_type = mime_type

    # Size from stat
    if stat_result is None:
        try:
            stat_result = os.stat(abs_path, follow_symlinks=True)
        except OSError:
            # Best-effort: content_length keeps its default.
            pass
    if stat_result is not None:
        meta.content_length = stat_result.st_size

    # Tier 2: Safetensors header (if applicable and enabled)
    if enable_safetensors and ext in SAFETENSORS_EXTENSIONS:
        # The header read is inside the try too, so a corrupt or hostile file
        # cannot abort the scan through an exception from the reader.
        try:
            header = _read_safetensors_header(abs_path)
            if header:
                _extract_safetensors_metadata(header, meta)
        except Exception as e:
            logging.debug("Failed to extract safetensors metadata from %s: %s", abs_path, e)

    return meta
|