Make ingest_file_from_path and register_existing_asset private

Amp-Thread-ID: https://ampcode.com/threads/T-019c2fe5-a3de-71cc-a6e5-67fe944a101e
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr
2026-02-05 14:26:36 -08:00
parent 56e9a75ca2
commit 28c4b58dd6
8 changed files with 783 additions and 26 deletions

View File

@@ -24,6 +24,7 @@ from app.assets.services.file_utils import (
list_files_recursively,
verify_file_unchanged,
)
from app.assets.services.metadata_extract import extract_file_metadata
from app.assets.services.path_utils import (
compute_relative_filename,
get_comfy_models_folders,
@@ -249,8 +250,15 @@ def _collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
def _build_asset_specs(
paths: list[str],
existing_paths: set[str],
enable_metadata_extraction: bool = True,
) -> tuple[list[SeedAssetSpec], set[str], int]:
"""Build asset specs from paths, returning (specs, tag_pool, skipped_count)."""
"""Build asset specs from paths, returning (specs, tag_pool, skipped_count).
Args:
paths: List of file paths to process
existing_paths: Set of paths that already exist in the database
enable_metadata_extraction: If True, extract tier 1 & 2 metadata from files
"""
specs: list[SeedAssetSpec] = []
tag_pool: set[str] = set()
skipped = 0
@@ -267,6 +275,18 @@ def _build_asset_specs(
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
rel_fname = compute_relative_filename(abs_p)
# Extract metadata (tier 1: filesystem, tier 2: safetensors header)
metadata = None
if enable_metadata_extraction:
metadata = extract_file_metadata(
abs_p,
stat_result=stat_p,
enable_safetensors=True,
relative_filename=rel_fname,
)
specs.append(
{
"abs_path": abs_p,
@@ -274,7 +294,8 @@ def _build_asset_specs(
"mtime_ns": get_mtime_ns(stat_p),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(abs_p),
"fname": rel_fname,
"metadata": metadata,
}
)
tag_pool.update(tags)

View File

@@ -23,8 +23,6 @@ from app.assets.services.ingest import (
DependencyMissingError,
HashMismatchError,
create_from_hash,
ingest_file_from_path,
register_existing_asset,
upload_from_temp_path,
)
from app.assets.services.schemas import (
@@ -76,12 +74,10 @@ __all__ = [
"get_asset_detail",
"get_mtime_ns",
"get_size_and_mtime_ns",
"ingest_file_from_path",
"list_assets_page",
"list_files_recursively",
"list_tags",
"prune_orphaned_assets",
"register_existing_asset",
"remove_tags",
"resolve_asset_for_download",
"set_asset_preview",

View File

@@ -1,10 +1,15 @@
from __future__ import annotations
import os
import uuid
from dataclasses import dataclass
from typing import TypedDict
from typing import TYPE_CHECKING, Any, TypedDict
from sqlalchemy.orm import Session
if TYPE_CHECKING:
from app.assets.services.metadata_extract import ExtractedMetadata
class SeedAssetSpec(TypedDict):
"""Spec for seeding an asset from filesystem."""
@@ -15,6 +20,7 @@ class SeedAssetSpec(TypedDict):
info_name: str
tags: list[str]
fname: str
metadata: ExtractedMetadata | None
from app.assets.database.queries import (
bulk_insert_asset_infos_ignore_conflicts,
@@ -98,18 +104,28 @@ def batch_insert_seed_assets(
"mtime_ns": sp["mtime_ns"],
}
)
# Build user_metadata from extracted metadata or fallback to filename
extracted = sp.get("metadata")
if extracted:
user_metadata: dict[str, Any] | None = extracted.to_user_metadata()
elif sp["fname"]:
user_metadata = {"filename": sp["fname"]}
else:
user_metadata = None
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"user_metadata": user_metadata,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
"_extracted_metadata": extracted,
}
bulk_insert_assets(session, asset_rows)
@@ -166,7 +182,13 @@ def batch_insert_seed_assets(
"added_at": now,
}
)
if row["_filename"]:
# Use extracted metadata for meta rows if available
extracted = row.get("_extracted_metadata")
if extracted:
meta_rows.extend(extracted.to_meta_rows(iid))
elif row["_filename"]:
# Fallback: just store filename
meta_rows.append(
{
"asset_info_id": iid,

View File

@@ -40,7 +40,7 @@ from app.assets.services.schemas import (
from app.database.db import create_session
def ingest_file_from_path(
def _ingest_file_from_path(
abs_path: str,
asset_hash: str,
size_bytes: int,
@@ -134,7 +134,7 @@ def ingest_file_from_path(
)
def register_existing_asset(
def _register_existing_asset(
asset_hash: str,
name: str,
user_metadata: UserMetadata = None,
@@ -278,7 +278,7 @@ def upload_from_temp_path(
os.remove(temp_path)
display_name = _sanitize_filename(name or client_filename, fallback=digest)
result = register_existing_asset(
result = _register_existing_asset(
asset_hash=asset_hash,
name=display_name,
user_metadata=user_metadata or {},
@@ -320,7 +320,7 @@ def upload_from_temp_path(
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
ingest_result = ingest_file_from_path(
ingest_result = _ingest_file_from_path(
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
@@ -369,7 +369,7 @@ def create_from_hash(
if not asset:
return None
result = register_existing_asset(
result = _register_existing_asset(
asset_hash=canonical,
name=_sanitize_filename(
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical

View File

@@ -0,0 +1,307 @@
"""Metadata extraction for asset scanning.
Tier 1: Filesystem metadata (zero parsing)
Tier 2: Safetensors header metadata (fast JSON read only)
"""
from __future__ import annotations
import json
import logging
import mimetypes
import os
import struct
from dataclasses import dataclass
from typing import Any
# File extensions (lowercase, including the leading dot) that are treated as
# safetensors files. extract_file_metadata() only attempts a header read when
# the lowercased extension of the path is a member of this set.
SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"})
@dataclass
class ExtractedMetadata:
    """Everything learned about a single file while scanning it.

    Tier 1 fields describe the file itself (name, size, type). The remaining
    fields stay at their defaults unless a safetensors header supplied them.
    """

    # --- Tier 1: filesystem (always available) ---
    filename: str = ""
    content_length: int = 0
    content_type: str | None = None
    format: str = ""  # file extension without the dot

    # --- Tier 2: safetensors header (if available) ---
    base_model: str | None = None
    trained_words: list[str] | None = None
    air: str | None = None  # CivitAI AIR identifier
    has_preview_images: bool = False

    # --- Source provenance (populated if embedded in safetensors) ---
    source_url: str | None = None
    source_arn: str | None = None
    repo_url: str | None = None
    preview_url: str | None = None
    source_hash: str | None = None

    # --- HuggingFace specific ---
    repo_id: str | None = None
    revision: str | None = None
    filepath: str | None = None
    resolve_url: str | None = None

    def to_user_metadata(self) -> dict[str, Any]:
        """Build the dict stored in the AssetInfo.user_metadata JSON field.

        ``filename``, ``content_length`` and ``format`` are always present;
        every other key is included only when its value is truthy.
        """
        payload: dict[str, Any] = {
            "filename": self.filename,
            "content_length": self.content_length,
            "format": self.format,
        }

        leading_optional = ("content_type", "base_model", "trained_words", "air")
        trailing_optional = (
            # Source provenance
            "source_url",
            "source_arn",
            "repo_url",
            "preview_url",
            "source_hash",
            # HuggingFace
            "repo_id",
            "revision",
            "filepath",
            "resolve_url",
        )

        for attr in leading_optional:
            value = getattr(self, attr)
            if value:
                payload[attr] = value

        # The flag is stored as a literal True, and only when it is set.
        if self.has_preview_images:
            payload["has_preview_images"] = True

        for attr in trailing_optional:
            value = getattr(self, attr)
            if value:
                payload[attr] = value

        return payload

    def to_meta_rows(self, asset_info_id: str) -> list[dict]:
        """Build asset_info_meta rows for typed/indexed querying.

        Each row carries exactly one populated ``val_*`` column; ``val_json``
        is never used. Empty strings and ``None`` produce no row.
        """
        rows: list[dict] = []

        def emit(key, *, ordinal=0, val_str=None, val_num=None, val_bool=None):
            # Single place that defines the row shape.
            rows.append({
                "asset_info_id": asset_info_id,
                "key": key,
                "ordinal": ordinal,
                "val_str": val_str,
                "val_num": val_num,
                "val_bool": val_bool,
                "val_json": None,
            })

        def emit_text(key, text, ordinal=0):
            if not text:
                return
            # String values are capped at 2048 characters.
            clipped = text if len(text) <= 2048 else text[:2048]
            emit(key, ordinal=ordinal, val_str=clipped)

        # Tier 1
        emit_text("filename", self.filename)
        if self.content_length is not None:
            emit("content_length", val_num=self.content_length)
        emit_text("content_type", self.content_type)
        emit_text("format", self.format)

        # Tier 2
        emit_text("base_model", self.base_model)
        emit_text("air", self.air)
        if self.has_preview_images:
            emit("has_preview_images", val_bool=self.has_preview_images)

        # One row per trained word (first 100 only); the ordinal is the
        # word's position in the list, so skipped empty words leave gaps.
        if self.trained_words:
            for position, word in enumerate(self.trained_words[:100]):
                emit_text("trained_words", word, ordinal=position)

        # Source provenance, then HuggingFace fields
        for attr in (
            "source_url",
            "source_arn",
            "repo_url",
            "preview_url",
            "source_hash",
            "repo_id",
            "revision",
            "filepath",
            "resolve_url",
        ):
            emit_text(attr, getattr(self, attr))

        return rows
def _read_safetensors_header(path: str, max_size: int = 8 * 1024 * 1024) -> dict[str, Any] | None:
"""Read only the JSON header from a safetensors file.
This is very fast - reads 8 bytes for header length, then the JSON header.
No tensor data is loaded.
Args:
path: Absolute path to safetensors file
max_size: Maximum header size to read (default 8MB)
Returns:
Parsed header dict or None if failed
"""
try:
with open(path, "rb") as f:
header_bytes = f.read(8)
if len(header_bytes) < 8:
return None
length_of_header = struct.unpack("<Q", header_bytes)[0]
if length_of_header > max_size:
return None
header_data = f.read(length_of_header)
if len(header_data) < length_of_header:
return None
return json.loads(header_data.decode("utf-8"))
except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error):
return None
def _first_str(st_meta: dict[str, Any], *keys: str) -> str | None:
    """Return the first non-empty string stored under any of ``keys``, else None."""
    for key in keys:
        value = st_meta.get(key)
        if isinstance(value, str) and value:
            return value
    return None


def _parse_trained_words(raw: Any) -> list[str] | None:
    """Normalize a ``trained_words`` metadata value to a list of strings.

    Accepts a list, a JSON-encoded list, or comma-separated text. Non-string
    list items are dropped. Returns None when no words can be recovered.
    """
    if isinstance(raw, str):
        try:
            decoded = json.loads(raw)
        except (ValueError, RecursionError):
            decoded = None
        if isinstance(decoded, list):
            raw = decoded
        else:
            # Plain text, or a JSON scalar/object: fall back to comma splitting.
            text = decoded if isinstance(decoded, str) else raw
            return [w.strip() for w in text.split(",") if w.strip()] or None
    if isinstance(raw, list):
        return [w for w in raw if isinstance(w, str)] or None
    return None


def _extract_safetensors_metadata(header: dict[str, Any], meta: ExtractedMetadata) -> None:
    """Copy known fields from a safetensors header's ``__metadata__`` into ``meta``.

    Modifies ``meta`` in place. Values with an unexpected type are ignored
    instead of being stored, so that one malformed entry neither aborts the
    extraction of the remaining fields nor leaves ``meta`` holding values
    that do not match its declared field types.

    Args:
        header: Parsed safetensors JSON header.
        meta: Metadata object to populate.
    """
    st_meta = header.get("__metadata__", {})
    if not isinstance(st_meta, dict):
        return

    # Common model metadata
    meta.base_model = _first_str(
        st_meta, "ss_base_model_version", "modelspec.base_model", "base_model"
    )

    # Trained words / trigger words: prefer the tag-frequency table, a JSON
    # object of {dataset: {tag: count}}. Collect unique tags across datasets.
    tag_frequency = st_meta.get("ss_tag_frequency")
    if tag_frequency and isinstance(tag_frequency, str):
        try:
            datasets = json.loads(tag_frequency)
        except (ValueError, RecursionError):
            datasets = None
        if isinstance(datasets, dict):
            all_tags: set[str] = set()
            for dataset_tags in datasets.values():
                if isinstance(dataset_tags, dict):
                    all_tags.update(dataset_tags)
            if all_tags:
                meta.trained_words = sorted(all_tags)[:100]

    # Direct trained_words field (some formats)
    if not meta.trained_words:
        direct_words = _parse_trained_words(st_meta.get("trained_words"))
        if direct_words:
            meta.trained_words = direct_words

    # CivitAI AIR
    meta.air = _first_str(st_meta, "air", "modelspec.air")

    # Preview images (ssmd_cover_images): only presence is recorded.
    if st_meta.get("ssmd_cover_images"):
        meta.has_preview_images = True

    # Source provenance fields
    meta.source_url = _first_str(st_meta, "source_url")
    meta.source_arn = _first_str(st_meta, "source_arn")
    meta.repo_url = _first_str(st_meta, "repo_url")
    meta.preview_url = _first_str(st_meta, "preview_url")
    meta.source_hash = _first_str(st_meta, "source_hash", "sshs_model_hash")

    # HuggingFace fields
    meta.repo_id = _first_str(st_meta, "repo_id", "hf_repo_id")
    meta.revision = _first_str(st_meta, "revision", "hf_revision")
    meta.filepath = _first_str(st_meta, "filepath", "hf_filepath")
    meta.resolve_url = _first_str(st_meta, "resolve_url", "hf_url")
def extract_file_metadata(
    abs_path: str,
    stat_result: os.stat_result | None = None,
    enable_safetensors: bool = True,
    relative_filename: str | None = None,
) -> ExtractedMetadata:
    """Collect tier 1 and, optionally, tier 2 metadata for one file.

    Tier 1 (always): filename, extension, guessed MIME type and size.
    Tier 2 (optional): fields from the safetensors header, attempted only
    when the file extension is one of SAFETENSORS_EXTENSIONS.

    Args:
        abs_path: Absolute path to the file.
        stat_result: Pre-fetched stat result; when omitted the file is
            stat'ed here, and a failing stat leaves the size at its default.
        enable_safetensors: Whether to parse safetensors headers (tier 2).
        relative_filename: Name to record instead of the basename of
            ``abs_path`` (e.g. "flux/123/model.safetensors").

    Returns:
        ExtractedMetadata with every field that could be determined.
    """
    suffix = os.path.splitext(abs_path)[1]

    # Size comes from the caller's stat when given, otherwise best effort.
    if stat_result is None:
        try:
            stat_result = os.stat(abs_path, follow_symlinks=True)
        except OSError:
            stat_result = None

    # Tier 1: filesystem metadata
    meta = ExtractedMetadata(
        filename=relative_filename or os.path.basename(abs_path),
        content_length=stat_result.st_size if stat_result else 0,
        content_type=mimetypes.guess_type(abs_path)[0],
        format=suffix.lstrip(".").lower(),
    )

    # Tier 2: safetensors header (if applicable and enabled)
    if not enable_safetensors or suffix.lower() not in SAFETENSORS_EXTENSIONS:
        return meta

    header = _read_safetensors_header(abs_path)
    if not header:
        return meta

    try:
        _extract_safetensors_metadata(header, meta)
    except Exception as exc:
        # Header metadata is optional; a failure here must not fail the scan.
        logging.debug("Failed to extract safetensors metadata from %s: %s", abs_path, exc)
    return meta