"""Metadata extraction for asset scanning. Tier 1: Filesystem metadata (zero parsing) Tier 2: Safetensors header metadata (fast JSON read only) """ from __future__ import annotations import json import logging import mimetypes import os import struct from dataclasses import dataclass from typing import Any # Supported safetensors extensions SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"}) # Maximum safetensors header size to read (8MB) MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024 def _register_custom_mime_types(): """Register custom MIME types for model and config files. Called before each use because mimetypes.init() in server.py resets the database. Uses a quick check to avoid redundant registrations. """ # Quick check if already registered (avoids redundant add_type calls) test_result, _ = mimetypes.guess_type("test.safetensors") if test_result == "application/safetensors": return mimetypes.add_type("application/safetensors", ".safetensors") mimetypes.add_type("application/safetensors", ".sft") mimetypes.add_type("application/pytorch", ".pt") mimetypes.add_type("application/pytorch", ".pth") mimetypes.add_type("application/pickle", ".ckpt") mimetypes.add_type("application/pickle", ".pkl") mimetypes.add_type("application/gguf", ".gguf") mimetypes.add_type("application/yaml", ".yaml") mimetypes.add_type("application/yaml", ".yml") # Register custom types at module load _register_custom_mime_types() @dataclass class ExtractedMetadata: """Metadata extracted from a file during scanning.""" # Tier 1: Filesystem (always available) filename: str = "" content_length: int = 0 content_type: str | None = None format: str = "" # file extension without dot # Tier 2: Safetensors header (if available) base_model: str | None = None trained_words: list[str] | None = None air: str | None = None # CivitAI AIR identifier has_preview_images: bool = False # Source provenance (populated if embedded in safetensors) source_url: str | None = None source_arn: str | None = None repo_url: str | None = None preview_url: str | None = None source_hash: str | None = None # HuggingFace specific repo_id: str | None = None revision: str | None = None filepath: str | None = None resolve_url: str | None = None def to_user_metadata(self) -> dict[str, Any]: """Convert to user_metadata dict for AssetInfo.user_metadata JSON field.""" data: dict[str, Any] = { "filename": self.filename, "content_length": self.content_length, "format": self.format, } if self.content_type: data["content_type"] = self.content_type # Tier 2 fields if self.base_model: data["base_model"] = self.base_model if self.trained_words: data["trained_words"] = self.trained_words if self.air: data["air"] = self.air if self.has_preview_images: data["has_preview_images"] = True # Source provenance if self.source_url: data["source_url"] = self.source_url if self.source_arn: data["source_arn"] = self.source_arn if self.repo_url: data["repo_url"] = self.repo_url if self.preview_url: data["preview_url"] = self.preview_url if self.source_hash: data["source_hash"] = self.source_hash # HuggingFace if self.repo_id: data["repo_id"] = self.repo_id if self.revision: data["revision"] = self.revision if self.filepath: data["filepath"] = self.filepath if self.resolve_url: data["resolve_url"] = self.resolve_url return data def to_meta_rows(self, asset_info_id: str) -> list[dict]: """Convert to asset_info_meta rows for typed/indexed querying.""" rows: list[dict] = [] def add_str(key: str, val: str | None, ordinal: int = 0) -> None: if val: rows.append({ 
"asset_info_id": asset_info_id, "key": key, "ordinal": ordinal, "val_str": val[:2048] if len(val) > 2048 else val, "val_num": None, "val_bool": None, "val_json": None, }) def add_num(key: str, val: int | float | None) -> None: if val is not None: rows.append({ "asset_info_id": asset_info_id, "key": key, "ordinal": 0, "val_str": None, "val_num": val, "val_bool": None, "val_json": None, }) def add_bool(key: str, val: bool | None) -> None: if val is not None: rows.append({ "asset_info_id": asset_info_id, "key": key, "ordinal": 0, "val_str": None, "val_num": None, "val_bool": val, "val_json": None, }) # Tier 1 add_str("filename", self.filename) add_num("content_length", self.content_length) add_str("content_type", self.content_type) add_str("format", self.format) # Tier 2 add_str("base_model", self.base_model) add_str("air", self.air) add_bool("has_preview_images", self.has_preview_images if self.has_preview_images else None) # trained_words as multiple rows with ordinals if self.trained_words: for i, word in enumerate(self.trained_words[:100]): # limit to 100 words add_str("trained_words", word, ordinal=i) # Source provenance add_str("source_url", self.source_url) add_str("source_arn", self.source_arn) add_str("repo_url", self.repo_url) add_str("preview_url", self.preview_url) add_str("source_hash", self.source_hash) # HuggingFace add_str("repo_id", self.repo_id) add_str("revision", self.revision) add_str("filepath", self.filepath) add_str("resolve_url", self.resolve_url) return rows def _read_safetensors_header(path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE) -> dict[str, Any] | None: """Read only the JSON header from a safetensors file. This is very fast - reads 8 bytes for header length, then the JSON header. No tensor data is loaded. Args: path: Absolute path to safetensors file max_size: Maximum header size to read (default 8MB) Returns: Parsed header dict or None if failed """ try: with open(path, "rb") as f: header_bytes = f.read(8) if len(header_bytes) < 8: return None length_of_header = struct.unpack(" max_size: return None header_data = f.read(length_of_header) if len(header_data) < length_of_header: return None return json.loads(header_data.decode("utf-8")) except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error): return None def _extract_safetensors_metadata(header: dict[str, Any], meta: ExtractedMetadata) -> None: """Extract metadata from safetensors header __metadata__ section. Modifies meta in-place. 
""" st_meta = header.get("__metadata__", {}) if not isinstance(st_meta, dict): return # Common model metadata meta.base_model = st_meta.get("ss_base_model_version") or st_meta.get("modelspec.base_model") or st_meta.get("base_model") # Trained words / trigger words trained_words = st_meta.get("ss_tag_frequency") if trained_words and isinstance(trained_words, str): try: tag_freq = json.loads(trained_words) # Extract unique tags from all datasets all_tags: set[str] = set() for dataset_tags in tag_freq.values(): if isinstance(dataset_tags, dict): all_tags.update(dataset_tags.keys()) if all_tags: meta.trained_words = sorted(all_tags)[:100] except json.JSONDecodeError: pass # Direct trained_words field (some formats) if not meta.trained_words: tw = st_meta.get("trained_words") if isinstance(tw, str): try: meta.trained_words = json.loads(tw) except json.JSONDecodeError: meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()] elif isinstance(tw, list): meta.trained_words = tw # CivitAI AIR meta.air = st_meta.get("air") or st_meta.get("modelspec.air") # Preview images (ssmd_cover_images) cover_images = st_meta.get("ssmd_cover_images") if cover_images: meta.has_preview_images = True # Source provenance fields meta.source_url = st_meta.get("source_url") meta.source_arn = st_meta.get("source_arn") meta.repo_url = st_meta.get("repo_url") meta.preview_url = st_meta.get("preview_url") meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash") # HuggingFace fields meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id") meta.revision = st_meta.get("revision") or st_meta.get("hf_revision") meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath") meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url") def extract_file_metadata( abs_path: str, stat_result: os.stat_result | None = None, enable_safetensors: bool = True, relative_filename: str | None = None, ) -> ExtractedMetadata: """Extract metadata from a file using tier 1 and optionally tier 2 methods. 

def extract_file_metadata(
    abs_path: str,
    stat_result: os.stat_result | None = None,
    enable_safetensors: bool = True,
    relative_filename: str | None = None,
) -> ExtractedMetadata:
    """Extract metadata from a file using tier 1 and optionally tier 2 methods.

    Tier 1 (always): Filesystem metadata from path and stat
    Tier 2 (optional): Safetensors header parsing if applicable

    Args:
        abs_path: Absolute path to the file
        stat_result: Optional pre-fetched stat result (saves a syscall)
        enable_safetensors: Whether to parse safetensors headers (tier 2)
        relative_filename: Optional relative filename to use instead of basename
            (e.g., "flux/123/model.safetensors" for model paths)

    Returns:
        ExtractedMetadata with all available fields populated
    """
    meta = ExtractedMetadata()

    # Tier 1: Filesystem metadata
    # Use relative_filename if provided (for backward compatibility with existing behavior)
    meta.filename = relative_filename if relative_filename else os.path.basename(abs_path)
    _, ext = os.path.splitext(abs_path)
    meta.format = ext.lstrip(".").lower() if ext else ""

    # MIME type guess (re-register in case mimetypes.init() was called elsewhere)
    _register_custom_mime_types()
    mime_type, _ = mimetypes.guess_type(abs_path)
    meta.content_type = mime_type

    # Size from stat
    if stat_result is None:
        try:
            stat_result = os.stat(abs_path, follow_symlinks=True)
        except OSError:
            pass
    if stat_result:
        meta.content_length = stat_result.st_size

    # Tier 2: Safetensors header (if applicable and enabled)
    if enable_safetensors and ext.lower() in SAFETENSORS_EXTENSIONS:
        header = _read_safetensors_header(abs_path)
        if header:
            try:
                _extract_safetensors_metadata(header, meta)
            except Exception as e:
                logging.debug("Failed to extract safetensors metadata from %s: %s", abs_path, e)

    return meta
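

# --- Illustrative usage sketch (not part of the scanning pipeline) ---
# The default path below is a hypothetical placeholder; pass any local file to try it.
if __name__ == "__main__":
    import sys

    demo_path = sys.argv[1] if len(sys.argv) > 1 else "model.safetensors"
    extracted = extract_file_metadata(demo_path)
    # Tier 1 fields are always populated; tier 2 fields appear only for safetensors files.
    print(json.dumps(extracted.to_user_metadata(), indent=2))
    print(f"{len(extracted.to_meta_rows('example-asset-info-id'))} meta rows")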