Compare commits

11 Commits

Author SHA1 Message Date
Luke Mino-Altherr
6d44ee0448 Include temp output files in asset registration 2026-03-24 19:26:52 -07:00
Luke Mino-Altherr
4266105b9e Revert unintended ruff format changes to preserve original code style 2026-03-24 18:34:53 -07:00
Luke Mino-Altherr
1b6f43eefb Fix enqueue_enrich race, missing tag seeding in ingest, and stale test mtime
- Make enqueue_enrich atomic by moving start_enrich call inside self._lock,
  preventing pending work from being lost when a scan finishes between the
  start attempt and the queue write.
- Call ensure_tags_exist before batch_insert_seed_assets in
  ingest_existing_file to avoid FK violations on asset_reference_tags.
- Fix test_enrich helper to use real file mtime instead of hardcoded value
  so the optimistic staleness guard in enrich_asset passes correctly.
- Add db_engine_fk fixture (SQLite with PRAGMA foreign_keys=ON) and a
  regression test proving ingest_existing_file seeds Tag rows before
  inserting reference tags.
2026-03-24 18:34:53 -07:00
Luke Mino-Altherr
83997d228d Fix shared-asset overwrite corruption, stale enrichment race, and path validation
- Detach ref to new stub asset on overwrite when siblings share the asset
- Add optimistic mtime_ns guard in enrich_asset to discard stale results
- Normalize output paths, validate that they stay under the output root, and deduplicate them
- Skip metadata extraction for stub-only registration (align with fast scan)
- Add RLock comment explaining re-entrant drain requirement
- Log warning when pending enrich drain fails to start
- Add create_stub_asset and count_active_siblings query functions

Amp-Thread-ID: https://ampcode.com/threads/T-019cfe06-f0dc-776f-81ad-e9f3d71be597
2026-03-24 18:34:53 -07:00
Luke Mino-Altherr
cb26eb30f2 Fix race in enqueue_enrich drain: make pending-to-start handoff atomic
Change _lock from Lock to RLock and move the start_enrich call inside the
lock-held block so that enqueue_enrich cannot interleave between clearing
_pending_enrich and starting the enrichment scan. This prevents a concurrent
enqueue_enrich from stealing the IDLE slot and causing the drained payload
to be silently dropped.

Add tests covering:
- pending enrich runs after scan completes
- enqueue during drain does not lose work
- concurrent enqueue during drain is queued for the next cycle

Amp-Thread-ID: https://ampcode.com/threads/T-019cfe02-5710-7506-ae80-34bf16c0171a
2026-03-24 18:34:53 -07:00
Luke Mino-Altherr
a3522f34c3 Use ExtractedMetadata in ingest_existing_file instead of passing raw dict
Have ingest_existing_file call extract_file_metadata() to build a proper
ExtractedMetadata object, matching what the scanner does. This tightens
SeedAssetSpec.metadata to ExtractedMetadata | None and removes dict-handling
branches in bulk_ingest.py that would have raised AttributeError on
to_meta_rows()/to_user_metadata().

Amp-Thread-ID: https://ampcode.com/threads/T-019cfdf9-2379-723a-82cf-306755e54396
2026-03-24 18:34:53 -07:00
Luke Mino-Altherr
e83e1d2edc Update existing asset references on overwritten output files
ingest_existing_file now detects when a reference already exists for the
path and updates mtime_ns, job_id, size_bytes, resets enrichment_level
and clears the asset hash so the enricher re-hashes the new content.

Only brand-new paths fall through to batch_insert_seed_assets.
register_output_files only increments its counter on actual insert or
update.

Amp-Thread-ID: https://ampcode.com/threads/T-019cfdf4-c52c-771c-a920-57bac15c68be
2026-03-24 18:34:53 -07:00
Luke Mino-Altherr
e24eddd439 Register output files immediately per-prompt, defer enqueue_enrich to GC
Move register_output_files() out of the periodic GC branch so it runs
right after each prompt completes, using the local e.history_result and
prompt_id. This prevents stale/overwritten values when multiple prompts
finish before GC triggers.

Keep enqueue_enrich() in the GC path since it's heavier and benefits
from batching via the 10-second interval.

Amp-Thread-ID: https://ampcode.com/threads/T-019cfdf4-c52c-771c-a920-57bac15c68be
2026-03-24 18:34:53 -07:00
Luke Mino-Altherr
878c03e468 Guard against None base_dir in _collect_output_absolute_paths
Return early with an empty list when folder_paths.get_directory_by_type('output')
returns None to avoid os.path.join(None, ...) producing invalid paths.

Amp-Thread-ID: https://ampcode.com/threads/T-019cfbe5-dffc-760a-9a37-9b041dd71e73
2026-03-24 18:34:53 -07:00
Luke Mino-Altherr
88a79075aa fix: preserve user_metadata in ingest_existing_file instead of dropping it
Pass user_metadata through spec['metadata'] to batch_insert_seed_assets.
Update batch_insert_seed_assets to accept raw dicts (UserMetadata) in
addition to ExtractedMetadata, passing them through as-is without
calling .to_user_metadata().

Amp-Thread-ID: https://ampcode.com/threads/T-019cfbe3-6440-724b-a17b-66ce09ecd1ed
2026-03-24 18:34:53 -07:00
Luke Mino-Altherr
e4bfba5978 feat(assets): register output files as assets after prompt execution
Add ingest_existing_file() to services/ingest.py as a public wrapper for
registering on-disk files (stat, BLAKE3 hash, MIME detection, path-based
tag derivation).

After each prompt execution in the main loop, iterate
history_result['outputs'] and register files with type 'output' as
assets. Runs while the asset seeder is paused, gated behind
asset_seeder.is_disabled(). Populates job_id on the asset_references
table for provenance tracking.

Ingest uses a two-phase approach: insert a stub record (hash=NULL) first
for instant visibility, then defer hashing to the background seeder
enrich phase to avoid blocking the prompt worker thread.

When multiple enrich scans are enqueued while the seeder is busy, roots
are now unioned and compute_hashes uses sticky-true (OR) logic so no
queued work is silently dropped.

Extract _reset_to_idle helper in the asset seeder to deduplicate the
state reset pattern shared by _run_scan and mark_missing_outside_prefixes.

Separate history parsing from output file registration: move generic
file registration logic into register_output_files() in
app/assets/services/ingest.py, keeping only the ComfyUI history format
parsing (_collect_output_absolute_paths) in main.py.

Amp-Thread-ID: https://ampcode.com/threads/T-019cf842-5199-71f0-941d-b420b5cf4d57
2026-03-24 18:34:53 -07:00
14 changed files with 764 additions and 14 deletions
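
The queueing rule described in e4bfba5978 (roots are unioned, compute_hashes is sticky-true) can be sketched in isolation. This is a minimal illustration, not the seeder's code: `PendingEnrich` and `merge_pending` are hypothetical names, and sorting the merged roots is an assumption added here to make the result deterministic.

```python
from __future__ import annotations

from typing import Optional, Sequence, TypedDict


class PendingEnrich(TypedDict):
    # Hypothetical shape of a queued enrich request.
    roots: tuple[str, ...]
    compute_hashes: bool


def merge_pending(
    pending: Optional[PendingEnrich],
    roots: Sequence[str],
    compute_hashes: bool,
) -> PendingEnrich:
    """Fold a new enrich request into the one already queued, if any."""
    if pending is None:
        return {"roots": tuple(roots), "compute_hashes": compute_hashes}
    # Union the roots so no requested root is forgotten (sorted for stable output).
    merged_roots = tuple(sorted(set(pending["roots"]) | set(roots)))
    # Sticky-true: once any request asked for hashes, the merged one does too.
    return {
        "roots": merged_roots,
        "compute_hashes": pending["compute_hashes"] or compute_hashes,
    }
```

Folding requests this way means a later `compute_hashes=False` request cannot downgrade an earlier request that asked for hashing.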


@@ -1,6 +1,7 @@
from app.assets.database.queries.asset import ( from app.assets.database.queries.asset import (
asset_exists_by_hash, asset_exists_by_hash,
bulk_insert_assets, bulk_insert_assets,
create_stub_asset,
get_asset_by_hash, get_asset_by_hash,
get_existing_asset_ids, get_existing_asset_ids,
reassign_asset_references, reassign_asset_references,
@@ -12,6 +13,7 @@ from app.assets.database.queries.asset_reference import (
UnenrichedReferenceRow, UnenrichedReferenceRow,
bulk_insert_references_ignore_conflicts, bulk_insert_references_ignore_conflicts,
bulk_update_enrichment_level, bulk_update_enrichment_level,
count_active_siblings,
bulk_update_is_missing, bulk_update_is_missing,
bulk_update_needs_verify, bulk_update_needs_verify,
convert_metadata_to_rows, convert_metadata_to_rows,
@@ -80,6 +82,8 @@ __all__ = [
"bulk_insert_references_ignore_conflicts", "bulk_insert_references_ignore_conflicts",
"bulk_insert_tags_and_meta", "bulk_insert_tags_and_meta",
"bulk_update_enrichment_level", "bulk_update_enrichment_level",
"count_active_siblings",
"create_stub_asset",
"bulk_update_is_missing", "bulk_update_is_missing",
"bulk_update_needs_verify", "bulk_update_needs_verify",
"convert_metadata_to_rows", "convert_metadata_to_rows",


@@ -78,6 +78,18 @@ def upsert_asset(
     return asset, created, updated


+def create_stub_asset(
+    session: Session,
+    size_bytes: int,
+    mime_type: str | None = None,
+) -> Asset:
+    """Create a new asset with no hash (stub for later enrichment)."""
+    asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None)
+    session.add(asset)
+    session.flush()
+    return asset
+
+
 def bulk_insert_assets(
     session: Session,
     rows: list[dict],


@@ -114,6 +114,23 @@ def get_reference_by_file_path(
     )


+def count_active_siblings(
+    session: Session,
+    asset_id: str,
+    exclude_reference_id: str,
+) -> int:
+    """Count active (non-deleted) references to an asset, excluding one reference."""
+    return (
+        session.query(AssetReference)
+        .filter(
+            AssetReference.asset_id == asset_id,
+            AssetReference.id != exclude_reference_id,
+            AssetReference.deleted_at.is_(None),
+        )
+        .count()
+    )
+
+
 def reference_exists_for_asset_id(
     session: Session,
     asset_id: str,


@@ -13,6 +13,7 @@ from app.assets.database.queries import (
delete_references_by_ids, delete_references_by_ids,
ensure_tags_exist, ensure_tags_exist,
get_asset_by_hash, get_asset_by_hash,
get_reference_by_id,
get_references_for_prefixes, get_references_for_prefixes,
get_unenriched_references, get_unenriched_references,
mark_references_missing_outside_prefixes, mark_references_missing_outside_prefixes,
@@ -338,6 +339,7 @@ def build_asset_specs(
"metadata": metadata, "metadata": metadata,
"hash": asset_hash, "hash": asset_hash,
"mime_type": mime_type, "mime_type": mime_type,
"job_id": None,
} }
) )
tag_pool.update(tags) tag_pool.update(tags)
@@ -426,6 +428,7 @@ def enrich_asset(
except OSError: except OSError:
return new_level return new_level
initial_mtime_ns = get_mtime_ns(stat_p)
rel_fname = compute_relative_filename(file_path) rel_fname = compute_relative_filename(file_path)
mime_type: str | None = None mime_type: str | None = None
metadata = None metadata = None
@@ -489,6 +492,18 @@ def enrich_asset(
except Exception as e: except Exception as e:
logging.warning("Failed to hash %s: %s", file_path, e) logging.warning("Failed to hash %s: %s", file_path, e)
# Optimistic guard: if the reference's mtime_ns changed since we
# started (e.g. ingest_existing_file updated it), our results are
# stale — discard them to avoid overwriting fresh registration data.
ref = get_reference_by_id(session, reference_id)
if ref is None or ref.mtime_ns != initial_mtime_ns:
session.rollback()
logging.info(
"Ref %s mtime changed during enrichment, discarding stale result",
reference_id,
)
return ENRICHMENT_STUB
if extract_metadata and metadata: if extract_metadata and metadata:
system_metadata = metadata.to_user_metadata() system_metadata = metadata.to_user_metadata()
set_reference_system_metadata(session, reference_id, system_metadata) set_reference_system_metadata(session, reference_id, system_metadata)
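
The optimistic guard added to enrich_asset can be illustrated without a database. The sketch below is a simplified stand-in, assuming an in-memory dict in place of the session; `Ref`, `enrich`, and the `compute_hash` callback are hypothetical names, not the scanner's API.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Dict, Optional


@dataclass
class Ref:
    # Hypothetical stand-in for an asset reference row.
    mtime_ns: int
    hash: Optional[str] = None


def enrich(
    store: Dict[str, Ref],
    ref_id: str,
    observed_mtime_ns: int,
    compute_hash: Callable[[], str],
) -> bool:
    """Hash a file, but keep the result only if the reference is unchanged.

    observed_mtime_ns is the file mtime seen when enrichment started.
    Returns True if the hash was stored, False if it was discarded as stale.
    """
    digest = compute_hash()  # slow work; the reference may change meanwhile
    ref = store.get(ref_id)
    # Re-read after the slow step: a missing row or a different mtime means
    # someone re-registered the file, so this result describes old content.
    if ref is None or ref.mtime_ns != observed_mtime_ns:
        return False
    ref.hash = digest
    return True
```

The check is optimistic in the sense that no lock is held during hashing; a conflicting update is detected afterwards and the stale result is dropped.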


@@ -77,7 +77,9 @@ class _AssetSeeder:
     """

     def __init__(self) -> None:
-        self._lock = threading.Lock()
+        # RLock is required because _run_scan() drains pending work while
+        # holding _lock and re-enters start() which also acquires _lock.
+        self._lock = threading.RLock()
         self._state = State.IDLE
         self._progress: Progress | None = None
         self._last_progress: Progress | None = None
@@ -92,6 +94,7 @@ class _AssetSeeder:
self._prune_first: bool = False self._prune_first: bool = False
self._progress_callback: ProgressCallback | None = None self._progress_callback: ProgressCallback | None = None
self._disabled: bool = False self._disabled: bool = False
self._pending_enrich: dict | None = None
def disable(self) -> None: def disable(self) -> None:
"""Disable the asset seeder, preventing any scans from starting.""" """Disable the asset seeder, preventing any scans from starting."""
@@ -196,6 +199,42 @@ class _AssetSeeder:
compute_hashes=compute_hashes, compute_hashes=compute_hashes,
) )
def enqueue_enrich(
self,
roots: tuple[RootType, ...] = ("models", "input", "output"),
compute_hashes: bool = False,
) -> bool:
"""Start an enrichment scan now, or queue it for after the current scan.
If the seeder is idle, starts immediately. Otherwise, the enrich
request is stored and will run automatically when the current scan
finishes.
Args:
roots: Tuple of root types to scan
compute_hashes: If True, compute blake3 hashes
Returns:
True if started immediately, False if queued for later
"""
with self._lock:
if self.start_enrich(roots=roots, compute_hashes=compute_hashes):
return True
if self._pending_enrich is not None:
existing_roots = set(self._pending_enrich["roots"])
existing_roots.update(roots)
self._pending_enrich["roots"] = tuple(existing_roots)
self._pending_enrich["compute_hashes"] = (
self._pending_enrich["compute_hashes"] or compute_hashes
)
else:
self._pending_enrich = {
"roots": roots,
"compute_hashes": compute_hashes,
}
logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"])
return False
def cancel(self) -> bool: def cancel(self) -> bool:
"""Request cancellation of the current scan. """Request cancellation of the current scan.
@@ -381,9 +420,13 @@ class _AssetSeeder:
             return marked
         finally:
             with self._lock:
-                self._last_progress = self._progress
-                self._state = State.IDLE
-                self._progress = None
+                self._reset_to_idle()
+
+    def _reset_to_idle(self) -> None:
+        """Reset state to IDLE, preserving last progress. Caller must hold _lock."""
+        self._last_progress = self._progress
+        self._state = State.IDLE
+        self._progress = None

     def _is_cancelled(self) -> bool:
         """Check if cancellation has been requested."""
@@ -594,9 +637,18 @@ class _AssetSeeder:
                 },
             )
             with self._lock:
-                self._last_progress = self._progress
-                self._state = State.IDLE
-                self._progress = None
+                self._reset_to_idle()
+                pending = self._pending_enrich
+                if pending is not None:
+                    self._pending_enrich = None
+                    if not self.start_enrich(
+                        roots=pending["roots"],
+                        compute_hashes=pending["compute_hashes"],
+                    ):
+                        logging.warning(
+                            "Pending enrich scan could not start (roots=%s)",
+                            pending["roots"],
+                        )

     def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
         """Run phase 1: fast scan to create stub records.

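The reason the seeder needs a re-entrant lock can be shown with a much smaller class. This is a sketch under stated assumptions: `MiniSeeder` and its methods are hypothetical and only mirror the shape of enqueue_enrich and the drain at the end of a scan; there is no background thread here.

```python
import threading


class MiniSeeder:
    """Hypothetical, single-threaded model of the start/enqueue/drain handoff."""

    def __init__(self) -> None:
        # RLock lets enqueue() and finish() call start() while holding the lock.
        # With a plain threading.Lock the nested acquire would deadlock.
        self._lock = threading.RLock()
        self._busy = False
        self._pending = None
        self.started = []

    def start(self, roots) -> bool:
        with self._lock:
            if self._busy:
                return False
            self._busy = True
            self.started.append(roots)
            return True

    def enqueue(self, roots) -> bool:
        # Try-start and queue-write happen under one lock hold, so a scan
        # cannot finish in between and leave the request stranded.
        with self._lock:
            if self.start(roots):
                return True
            self._pending = roots
            return False

    def finish(self) -> None:
        # Reset, take the pending request, and restart without releasing the
        # lock, so no other caller can claim the idle slot in the gap.
        with self._lock:
            self._busy = False
            pending, self._pending = self._pending, None
            if pending is not None:
                self.start(pending)
```

A short usage run: start a scan, enqueue while it is busy (queued, returns False), then call `finish()`; the queued roots are started as part of the same lock hold.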

@@ -23,6 +23,8 @@ from app.assets.services.ingest import (
DependencyMissingError, DependencyMissingError,
HashMismatchError, HashMismatchError,
create_from_hash, create_from_hash,
ingest_existing_file,
register_output_files,
upload_from_temp_path, upload_from_temp_path,
) )
from app.assets.database.queries import ( from app.assets.database.queries import (
@@ -72,6 +74,8 @@ __all__ = [
"delete_asset_reference", "delete_asset_reference",
"get_asset_by_hash", "get_asset_by_hash",
"get_asset_detail", "get_asset_detail",
"ingest_existing_file",
"register_output_files",
"get_mtime_ns", "get_mtime_ns",
"get_size_and_mtime_ns", "get_size_and_mtime_ns",
"list_assets_page", "list_assets_page",


@@ -37,6 +37,7 @@ class SeedAssetSpec(TypedDict):
metadata: ExtractedMetadata | None metadata: ExtractedMetadata | None
hash: str | None hash: str | None
mime_type: str | None mime_type: str | None
job_id: str | None
class AssetRow(TypedDict): class AssetRow(TypedDict):
@@ -60,6 +61,7 @@ class ReferenceRow(TypedDict):
name: str name: str
preview_id: str | None preview_id: str | None
user_metadata: dict[str, Any] | None user_metadata: dict[str, Any] | None
job_id: str | None
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
last_access_time: datetime last_access_time: datetime
@@ -167,6 +169,7 @@ def batch_insert_seed_assets(
"name": spec["info_name"], "name": spec["info_name"],
"preview_id": None, "preview_id": None,
"user_metadata": user_metadata, "user_metadata": user_metadata,
"job_id": spec.get("job_id"),
"created_at": current_time, "created_at": current_time,
"updated_at": current_time, "updated_at": current_time,
"last_access_time": current_time, "last_access_time": current_time,


@@ -9,6 +9,9 @@ from sqlalchemy.orm import Session
import app.assets.services.hashing as hashing import app.assets.services.hashing as hashing
from app.assets.database.queries import ( from app.assets.database.queries import (
add_tags_to_reference, add_tags_to_reference,
count_active_siblings,
create_stub_asset,
ensure_tags_exist,
fetch_reference_and_asset, fetch_reference_and_asset,
get_asset_by_hash, get_asset_by_hash,
get_reference_by_file_path, get_reference_by_file_path,
@@ -23,7 +26,8 @@ from app.assets.database.queries import (
     upsert_reference,
     validate_tags_exist,
 )
-from app.assets.helpers import normalize_tags
+from app.assets.helpers import get_utc_now, normalize_tags
+from app.assets.services.bulk_ingest import batch_insert_seed_assets
 from app.assets.services.file_utils import get_size_and_mtime_ns
 from app.assets.services.path_utils import (
     compute_relative_filename,
@@ -130,6 +134,102 @@ def _ingest_file_from_path(
) )
def register_output_files(
file_paths: Sequence[str],
user_metadata: UserMetadata = None,
job_id: str | None = None,
) -> int:
"""Register a batch of output file paths as assets.
Returns the number of files successfully registered.
"""
registered = 0
for abs_path in file_paths:
if not os.path.isfile(abs_path):
continue
try:
if ingest_existing_file(
abs_path, user_metadata=user_metadata, job_id=job_id
):
registered += 1
except Exception:
logging.exception("Failed to register output: %s", abs_path)
return registered
def ingest_existing_file(
abs_path: str,
user_metadata: UserMetadata = None,
extra_tags: Sequence[str] = (),
owner_id: str = "",
job_id: str | None = None,
) -> bool:
"""Register an existing on-disk file as an asset stub.
If a reference already exists for this path, updates mtime_ns, job_id,
size_bytes, and resets enrichment so the enricher will re-hash it.
For brand-new paths, inserts a stub record (hash=NULL) for immediate
UX visibility.
Returns True if a row was inserted or updated, False otherwise.
"""
locator = os.path.abspath(abs_path)
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
name, path_tags = get_name_and_tags_from_asset_path(abs_path)
tags = list(dict.fromkeys(path_tags + list(extra_tags)))
with create_session() as session:
existing_ref = get_reference_by_file_path(session, locator)
if existing_ref is not None:
now = get_utc_now()
existing_ref.mtime_ns = mtime_ns
existing_ref.job_id = job_id
existing_ref.is_missing = False
existing_ref.deleted_at = None
existing_ref.updated_at = now
existing_ref.enrichment_level = 0
asset = existing_ref.asset
if asset:
# If other refs share this asset, detach to a new stub
# instead of mutating the shared row.
siblings = count_active_siblings(session, asset.id, existing_ref.id)
if siblings > 0:
new_asset = create_stub_asset(
session,
size_bytes=size_bytes,
mime_type=mime_type or asset.mime_type,
)
existing_ref.asset_id = new_asset.id
else:
asset.hash = None
asset.size_bytes = size_bytes
if mime_type:
asset.mime_type = mime_type
session.commit()
return True
spec = {
"abs_path": abs_path,
"size_bytes": size_bytes,
"mtime_ns": mtime_ns,
"info_name": name,
"tags": tags,
"fname": os.path.basename(abs_path),
"metadata": None,
"hash": None,
"mime_type": mime_type,
"job_id": job_id,
}
if tags:
ensure_tags_exist(session, tags)
result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
session.commit()
return result.won_paths > 0
def _register_existing_asset( def _register_existing_asset(
asset_hash: str, asset_hash: str,
name: str, name: str,
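
The overwrite branch of ingest_existing_file behaves like copy-on-write for shared assets. The sketch below models that rule with plain dataclasses; `Store`, `Asset`, `Ref`, and `overwrite` are hypothetical names, and integer ids are an assumption made for brevity.

```python
from dataclasses import dataclass
from itertools import count
from typing import Dict, Optional


@dataclass
class Asset:
    id: int
    size_bytes: int
    hash: Optional[str]


@dataclass
class Ref:
    id: int
    asset_id: int


class Store:
    """Hypothetical in-memory stand-in for the assets and references tables."""

    def __init__(self) -> None:
        self.assets: Dict[int, Asset] = {}
        self.refs: Dict[int, Ref] = {}
        self._ids = count(1)

    def add_asset(self, size_bytes: int, hash: Optional[str]) -> Asset:
        asset = Asset(next(self._ids), size_bytes, hash)
        self.assets[asset.id] = asset
        return asset

    def add_ref(self, asset_id: int) -> Ref:
        ref = Ref(next(self._ids), asset_id)
        self.refs[ref.id] = ref
        return ref

    def count_siblings(self, asset_id: int, exclude_ref_id: int) -> int:
        return sum(
            1
            for r in self.refs.values()
            if r.asset_id == asset_id and r.id != exclude_ref_id
        )

    def overwrite(self, ref_id: int, new_size: int) -> None:
        """Record that the file behind ref_id now has new, unhashed content."""
        ref = self.refs[ref_id]
        asset = self.assets[ref.asset_id]
        if self.count_siblings(asset.id, ref.id) > 0:
            # Shared asset: point this ref at a fresh stub, leave the old row alone.
            ref.asset_id = self.add_asset(new_size, None).id
        else:
            # Sole owner: safe to reset the existing row in place.
            asset.hash = None
            asset.size_bytes = new_size
```

Detaching only when siblings exist keeps the common single-reference case free of extra rows while protecting other references from seeing a cleared hash.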

main.py (43 lines changed)

@@ -9,6 +9,8 @@ import folder_paths
import time import time
from comfy.cli_args import args, enables_dynamic_vram from comfy.cli_args import args, enables_dynamic_vram
from app.logger import setup_logger from app.logger import setup_logger
from app.assets.seeder import asset_seeder
from app.assets.services import register_output_files
import itertools import itertools
import utils.extra_config import utils.extra_config
from utils.mime_types import init_mime_types from utils.mime_types import init_mime_types
@@ -192,7 +194,6 @@ if 'torch' in sys.modules:
import comfy.utils import comfy.utils
from app.assets.seeder import asset_seeder
import execution import execution
import server import server
@@ -240,6 +241,38 @@ def cuda_malloc_warning():
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
"""Extract absolute file paths for output items from a history result."""
paths: list[str] = []
seen: set[str] = set()
for node_output in history_result.get("outputs", {}).values():
for items in node_output.values():
if not isinstance(items, list):
continue
for item in items:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if item_type not in ("output", "temp"):
continue
base_dir = folder_paths.get_directory_by_type(item_type)
if base_dir is None:
continue
base_dir = os.path.abspath(base_dir)
filename = item.get("filename")
if not filename:
continue
abs_path = os.path.abspath(
os.path.join(base_dir, item.get("subfolder", ""), filename)
)
if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
continue
if abs_path not in seen:
seen.add(abs_path)
paths.append(abs_path)
return paths
def prompt_worker(q, server_instance): def prompt_worker(q, server_instance):
current_time: float = 0.0 current_time: float = 0.0
cache_type = execution.CacheType.CLASSIC cache_type = execution.CacheType.CLASSIC
@@ -274,6 +307,7 @@ def prompt_worker(q, server_instance):
asset_seeder.pause() asset_seeder.pause()
e.execute(item[2], prompt_id, extra_data, item[4]) e.execute(item[2], prompt_id, extra_data, item[4])
need_gc = True need_gc = True
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
@@ -296,6 +330,10 @@ def prompt_worker(q, server_instance):
else: else:
logging.info("Prompt executed in {:.2f} seconds".format(execution_time)) logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
if not asset_seeder.is_disabled():
paths = _collect_output_absolute_paths(e.history_result)
register_output_files(paths, job_id=prompt_id)
flags = q.get_flags() flags = q.get_flags()
free_memory = flags.get("free_memory", False) free_memory = flags.get("free_memory", False)
@@ -317,6 +355,9 @@ def prompt_worker(q, server_instance):
last_gc_collect = current_time last_gc_collect = current_time
need_gc = False need_gc = False
hook_breaker_ac10a0.restore_functions() hook_breaker_ac10a0.restore_functions()
if not asset_seeder.is_disabled():
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
asset_seeder.resume() asset_seeder.resume()
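
The containment check in _collect_output_absolute_paths can be tried on its own. This is a minimal sketch, assuming hypothetical helper names (`resolve_under_root`, `collect_unique`); unlike the diff it does not accept a path equal to the root itself, since only files are of interest here.

```python
import os
from typing import Iterable, List, Optional, Tuple


def resolve_under_root(root: str, subfolder: str, filename: str) -> Optional[str]:
    """Return the absolute path if it stays inside root, else None."""
    root_abs = os.path.abspath(root)
    candidate = os.path.abspath(os.path.join(root_abs, subfolder, filename))
    # Appending os.sep prevents a sibling such as "out2" from matching "out".
    if not candidate.startswith(root_abs + os.sep):
        return None
    return candidate


def collect_unique(root: str, items: Iterable[Tuple[str, str]]) -> List[str]:
    """Resolve (subfolder, filename) pairs, dropping escapes and duplicates."""
    seen = set()
    paths: List[str] = []
    for subfolder, filename in items:
        path = resolve_under_root(root, subfolder, filename)
        if path is not None and path not in seen:
            seen.add(path)
            paths.append(path)
    return paths
```

Normalizing with `os.path.abspath` before the prefix test is what collapses `..` segments; the test on the raw joined string would not catch them.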


@@ -3,7 +3,7 @@ from pathlib import Path
 from unittest.mock import patch

 import pytest
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, event
 from sqlalchemy.orm import Session

 from app.assets.database.models import Base
@@ -23,6 +23,21 @@ def db_engine():
return engine return engine
@pytest.fixture
def db_engine_fk():
"""In-memory SQLite engine with foreign key enforcement enabled."""
engine = create_engine("sqlite:///:memory:")
@event.listens_for(engine, "connect")
def _set_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
Base.metadata.create_all(engine)
return engine
@pytest.fixture @pytest.fixture
def session(db_engine): def session(db_engine):
"""Session fixture for tests that need direct DB access.""" """Session fixture for tests that need direct DB access."""

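Why the db_engine_fk fixture matters: SQLite enforces foreign keys per connection, and only when the pragma is on. The sketch below shows the difference with the standard-library sqlite3 module; the two-table schema and the `insert_orphan_tag` helper are simplified, hypothetical stand-ins for the real tags and asset_reference_tags tables.

```python
import sqlite3

# Simplified, hypothetical schema: a reference tag must point at an existing tag.
SCHEMA = """
CREATE TABLE tags (name TEXT PRIMARY KEY);
CREATE TABLE reference_tags (
    ref_id TEXT NOT NULL,
    tag_name TEXT NOT NULL REFERENCES tags(name)
);
"""


def insert_orphan_tag(enforce_fk: bool) -> bool:
    """Insert a reference tag whose tag row is missing.

    Returns True if SQLite accepted the row, False if it raised IntegrityError.
    """
    conn = sqlite3.connect(":memory:")
    try:
        # The pragma is per connection and must be set outside a transaction.
        conn.execute(f"PRAGMA foreign_keys={'ON' if enforce_fk else 'OFF'}")
        conn.executescript(SCHEMA)
        try:
            conn.execute("INSERT INTO reference_tags VALUES ('r1', 'output')")
        except sqlite3.IntegrityError:
            return False
        return True
    finally:
        conn.close()
```

With enforcement off the orphan row is accepted silently, which is how a missing ensure_tags_exist call could go unnoticed in tests that use a plain in-memory engine.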

@@ -1,9 +1,11 @@
"""Tests for asset enrichment (mime_type and hash population).""" """Tests for asset enrichment (mime_type and hash population)."""
import os
from pathlib import Path from pathlib import Path
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference from app.assets.database.models import Asset, AssetReference
from app.assets.services.file_utils import get_mtime_ns
from app.assets.scanner import ( from app.assets.scanner import (
ENRICHMENT_HASHED, ENRICHMENT_HASHED,
ENRICHMENT_METADATA, ENRICHMENT_METADATA,
@@ -20,6 +22,13 @@ def _create_stub_asset(
name: str | None = None, name: str | None = None,
) -> tuple[Asset, AssetReference]: ) -> tuple[Asset, AssetReference]:
"""Create a stub asset with reference for testing enrichment.""" """Create a stub asset with reference for testing enrichment."""
# Use the real file's mtime so the optimistic guard in enrich_asset passes
try:
stat_result = os.stat(file_path, follow_symlinks=True)
mtime_ns = get_mtime_ns(stat_result)
except OSError:
mtime_ns = 1234567890000000000
asset = Asset( asset = Asset(
id=asset_id, id=asset_id,
hash=None, hash=None,
@@ -35,7 +44,7 @@ def _create_stub_asset(
         name=name or f"test-asset-{asset_id}",
         owner_id="system",
         file_path=file_path,
-        mtime_ns=1234567890000000000,
+        mtime_ns=mtime_ns,
         enrichment_level=ENRICHMENT_STUB,
     )
     session.add(ref)


@@ -1,12 +1,18 @@
 """Tests for ingest services."""
+from contextlib import contextmanager
 from pathlib import Path
+from unittest.mock import patch

 import pytest
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session as SASession, Session

-from app.assets.database.models import Asset, AssetReference, Tag
+from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
 from app.assets.database.queries import get_reference_tags
-from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset
+from app.assets.services.ingest import (
+    _ingest_file_from_path,
+    _register_existing_asset,
+    ingest_existing_file,
+)


 class TestIngestFileFromPath:
@@ -235,3 +241,42 @@ class TestRegisterExistingAsset:
assert result.created is True assert result.created is True
assert set(result.tags) == {"alpha", "beta"} assert set(result.tags) == {"alpha", "beta"}
class TestIngestExistingFileTagFK:
"""Regression: ingest_existing_file must seed Tag rows before inserting
AssetReferenceTag rows, otherwise FK enforcement raises IntegrityError."""
def test_creates_tag_rows_before_reference_tags(self, db_engine_fk, temp_dir: Path):
"""With PRAGMA foreign_keys=ON, tags must exist in the tags table
before they can be referenced in asset_reference_tags."""
@contextmanager
def _create_session():
with SASession(db_engine_fk) as sess:
yield sess
file_path = temp_dir / "output.png"
file_path.write_bytes(b"image data")
with patch("app.assets.services.ingest.create_session", _create_session), \
patch(
"app.assets.services.ingest.get_name_and_tags_from_asset_path",
return_value=("output.png", ["output"]),
):
result = ingest_existing_file(
abs_path=str(file_path),
extra_tags=["my-job"],
)
assert result is True
with SASession(db_engine_fk) as sess:
tag_names = {t.name for t in sess.query(Tag).all()}
assert "output" in tag_names
assert "my-job" in tag_names
ref_tags = sess.query(AssetReferenceTag).all()
ref_tag_names = {rt.tag_name for rt in ref_tags}
assert "output" in ref_tag_names
assert "my-job" in ref_tag_names


@@ -1,6 +1,7 @@
"""Unit tests for the _AssetSeeder background scanning class.""" """Unit tests for the _AssetSeeder background scanning class."""
import threading import threading
import time
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@@ -771,6 +772,188 @@ class TestSeederStopRestart:
assert collected_roots[1] == ("input",) assert collected_roots[1] == ("input",)
class TestEnqueueEnrichHandoff:
"""Test that the drain of _pending_enrich is atomic with start_enrich."""
def test_pending_enrich_runs_after_scan_completes(
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
"""A queued enrich request runs automatically when a scan finishes."""
enrich_roots_seen: list[tuple] = []
original_start = fresh_seeder.start
def tracking_start(*args, **kwargs):
phase = kwargs.get("phase")
roots = kwargs.get("roots", args[0] if args else None)
result = original_start(*args, **kwargs)
if phase == ScanPhase.ENRICH and result:
enrich_roots_seen.append(roots)
return result
fresh_seeder.start = tracking_start
# Start a fast scan, then enqueue an enrich while it's running
barrier = threading.Event()
reached = threading.Event()
def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0)
return []
with patch(
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
assert reached.wait(timeout=2.0)
queued = fresh_seeder.enqueue_enrich(
roots=("input",), compute_hashes=True
)
assert queued is False # queued, not started immediately
barrier.set()
# Wait for the original scan + the auto-started enrich scan
deadline = time.monotonic() + 5.0
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
time.sleep(0.05)
assert enrich_roots_seen == [("input",)]
def test_enqueue_enrich_during_drain_does_not_lose_work(
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
"""enqueue_enrich called concurrently with drain cannot drop work.
Simulates the race: another thread calls enqueue_enrich right as the
scan thread is draining _pending_enrich. The enqueue must either be
picked up by the draining scan or successfully start its own scan.
"""
barrier = threading.Event()
reached = threading.Event()
enrich_started = threading.Event()
enrich_call_count = 0
def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0)
return []
# Track how many times start_enrich actually fires
real_start_enrich = fresh_seeder.start_enrich
enrich_roots_seen: list[tuple] = []
def tracking_start_enrich(**kwargs):
nonlocal enrich_call_count
enrich_call_count += 1
enrich_roots_seen.append(kwargs.get("roots"))
result = real_start_enrich(**kwargs)
if result:
enrich_started.set()
return result
fresh_seeder.start_enrich = tracking_start_enrich
with patch(
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
# Start a scan
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
assert reached.wait(timeout=2.0)
# Queue an enrich while scan is running
fresh_seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
# Let scan finish — drain will fire start_enrich atomically
barrier.set()
# Wait for drain to complete and the enrich scan to start
assert enrich_started.wait(timeout=5.0), "Enrich scan was never started from drain"
assert ("output",) in enrich_roots_seen

    def test_concurrent_enqueue_during_drain_not_lost(
        self, fresh_seeder: _AssetSeeder,
    ):
        """A second enqueue_enrich arriving while drain is in progress is not lost.

        Because the drain now holds _lock through the start_enrich call,
        a concurrent enqueue_enrich will block until start_enrich has
        transitioned state to RUNNING, then the enqueue will queue its
        payload as _pending_enrich for the *next* drain.
        """
        scan_barrier = threading.Event()
        scan_reached = threading.Event()
        enrich_barrier = threading.Event()
        enrich_reached = threading.Event()
        collect_call = 0

        def gated_collect(*args):
            nonlocal collect_call
            collect_call += 1
            if collect_call == 1:
                # First call: the initial fast scan
                scan_reached.set()
                scan_barrier.wait(timeout=5.0)
            return []

        enrich_call = 0

        def gated_get_unenriched(*args, **kwargs):
            nonlocal enrich_call
            enrich_call += 1
            if enrich_call == 1:
                # First enrich batch: signal and block
                enrich_reached.set()
                enrich_barrier.wait(timeout=5.0)
            return []

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=gated_get_unenriched),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            # 1. Start fast scan
            fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
            assert scan_reached.wait(timeout=2.0)

            # 2. Queue enrich while fast scan is running
            queued = fresh_seeder.enqueue_enrich(
                roots=("input",), compute_hashes=False
            )
            assert queued is False

            # 3. Let the fast scan finish; drain will start the enrich scan
            scan_barrier.set()

            # 4. Wait until the drained enrich scan is running
            assert enrich_reached.wait(timeout=5.0)

            # 5. Now enqueue another enrich while the drained scan is running
            queued2 = fresh_seeder.enqueue_enrich(
                roots=("output",), compute_hashes=True
            )
            assert queued2 is False  # should be queued, not started

            # Verify _pending_enrich was set (the second enqueue was captured)
            with fresh_seeder._lock:
                assert fresh_seeder._pending_enrich is not None
                assert "output" in fresh_seeder._pending_enrich["roots"]

            # Let the enrich scan finish
            enrich_barrier.set()
            deadline = time.monotonic() + 5.0
            while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
                time.sleep(0.05)


def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
    return UnenrichedReferenceRow(
        reference_id=ref_id, asset_id=asset_id,

tests/test_asset_seeder.py Normal file

@@ -0,0 +1,250 @@
"""Tests for app.assets.seeder enqueue_enrich and pending-queue behaviour."""
import threading
from unittest.mock import patch

import pytest

from app.assets.seeder import Progress, _AssetSeeder, State


@pytest.fixture()
def seeder():
    """Fresh seeder instance for each test."""
    return _AssetSeeder()


# ---------------------------------------------------------------------------
# _reset_to_idle
# ---------------------------------------------------------------------------


class TestResetToIdle:
    def test_sets_idle_and_clears_progress(self, seeder):
        """_reset_to_idle should move state to IDLE and snapshot progress."""
        progress = Progress(scanned=10, total=20, created=5, skipped=3)
        seeder._state = State.RUNNING
        seeder._progress = progress

        with seeder._lock:
            seeder._reset_to_idle()

        assert seeder._state is State.IDLE
        assert seeder._progress is None
        assert seeder._last_progress is progress

    def test_noop_when_progress_already_none(self, seeder):
        """_reset_to_idle should handle None progress gracefully."""
        seeder._state = State.CANCELLING
        seeder._progress = None

        with seeder._lock:
            seeder._reset_to_idle()

        assert seeder._state is State.IDLE
        assert seeder._progress is None
        assert seeder._last_progress is None


# ---------------------------------------------------------------------------
# enqueue_enrich immediate start when idle
# ---------------------------------------------------------------------------


class TestEnqueueEnrichStartsImmediately:
    def test_starts_when_idle(self, seeder):
        """enqueue_enrich should delegate to start_enrich and return True when idle."""
        with patch.object(seeder, "start_enrich", return_value=True) as mock:
            assert seeder.enqueue_enrich(roots=("output",), compute_hashes=True) is True
            mock.assert_called_once_with(roots=("output",), compute_hashes=True)

    def test_no_pending_when_started_immediately(self, seeder):
        """No pending request should be stored when start_enrich succeeds."""
        with patch.object(seeder, "start_enrich", return_value=True):
            seeder.enqueue_enrich(roots=("output",))
        assert seeder._pending_enrich is None


# ---------------------------------------------------------------------------
# enqueue_enrich queuing when busy
# ---------------------------------------------------------------------------


class TestEnqueueEnrichQueuesWhenBusy:
    def test_queues_when_busy(self, seeder):
        """enqueue_enrich should store a pending request when seeder is busy."""
        with patch.object(seeder, "start_enrich", return_value=False):
            result = seeder.enqueue_enrich(roots=("models",), compute_hashes=False)

        assert result is False
        assert seeder._pending_enrich == {
            "roots": ("models",),
            "compute_hashes": False,
        }

    def test_queues_preserves_compute_hashes_true(self, seeder):
        with patch.object(seeder, "start_enrich", return_value=False):
            seeder.enqueue_enrich(roots=("input",), compute_hashes=True)

        assert seeder._pending_enrich["compute_hashes"] is True


# ---------------------------------------------------------------------------
# enqueue_enrich merging when a pending request already exists
# ---------------------------------------------------------------------------


class TestEnqueueEnrichMergesPending:
    def _make_busy(self, seeder):
        """Patch start_enrich to always return False (seeder busy)."""
        return patch.object(seeder, "start_enrich", return_value=False)

    def test_merges_roots(self, seeder):
        """A second enqueue should merge roots with the existing pending request."""
        with self._make_busy(seeder):
            seeder.enqueue_enrich(roots=("models",))
            seeder.enqueue_enrich(roots=("output",))

        merged = set(seeder._pending_enrich["roots"])
        assert merged == {"models", "output"}

    def test_merges_overlapping_roots(self, seeder):
        """Duplicate roots should be deduplicated."""
        with self._make_busy(seeder):
            seeder.enqueue_enrich(roots=("models", "input"))
            seeder.enqueue_enrich(roots=("input", "output"))

        merged = set(seeder._pending_enrich["roots"])
        assert merged == {"models", "input", "output"}

    def test_compute_hashes_sticky_true(self, seeder):
        """Once compute_hashes is True it should stay True after merging."""
        with self._make_busy(seeder):
            seeder.enqueue_enrich(roots=("models",), compute_hashes=True)
            seeder.enqueue_enrich(roots=("output",), compute_hashes=False)

        assert seeder._pending_enrich["compute_hashes"] is True

    def test_compute_hashes_upgrades_to_true(self, seeder):
        """A later enqueue with compute_hashes=True should upgrade the pending request."""
        with self._make_busy(seeder):
            seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
            seeder.enqueue_enrich(roots=("output",), compute_hashes=True)

        assert seeder._pending_enrich["compute_hashes"] is True

    def test_compute_hashes_stays_false(self, seeder):
        """If both enqueues have compute_hashes=False it stays False."""
        with self._make_busy(seeder):
            seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
            seeder.enqueue_enrich(roots=("output",), compute_hashes=False)

        assert seeder._pending_enrich["compute_hashes"] is False

    def test_triple_merge(self, seeder):
        """Three successive enqueues should all merge correctly."""
        with self._make_busy(seeder):
            seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
            seeder.enqueue_enrich(roots=("input",), compute_hashes=False)
            seeder.enqueue_enrich(roots=("output",), compute_hashes=True)

        merged = set(seeder._pending_enrich["roots"])
        assert merged == {"models", "input", "output"}
        assert seeder._pending_enrich["compute_hashes"] is True


# ---------------------------------------------------------------------------
# Pending enrich drains after scan completes
# ---------------------------------------------------------------------------


class TestPendingEnrichDrain:
    """Verify that _run_scan drains _pending_enrich via start_enrich."""

    @patch("app.assets.seeder.dependencies_available", return_value=True)
    @patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
    @patch("app.assets.seeder.sync_root_safely", return_value=set())
    @patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
    @patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
    def test_pending_enrich_starts_after_scan(self, *_mocks):
        """After a fast scan finishes, the pending enrich should be started."""
        seeder = _AssetSeeder()
        seeder._pending_enrich = {
            "roots": ("output",),
            "compute_hashes": True,
        }

        with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
            seeder.start_fast(roots=("models",))
            seeder.wait(timeout=5)

            mock_start.assert_called_once_with(
                roots=("output",),
                compute_hashes=True,
            )

        assert seeder._pending_enrich is None

    @patch("app.assets.seeder.dependencies_available", return_value=True)
    @patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
    @patch("app.assets.seeder.sync_root_safely", return_value=set())
    @patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
    @patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
    def test_pending_cleared_even_when_start_fails(self, *_mocks):
        """_pending_enrich should be cleared even if start_enrich returns False."""
        seeder = _AssetSeeder()
        seeder._pending_enrich = {
            "roots": ("output",),
            "compute_hashes": False,
        }

        with patch.object(seeder, "start_enrich", return_value=False):
            seeder.start_fast(roots=("models",))
            seeder.wait(timeout=5)

        assert seeder._pending_enrich is None

    @patch("app.assets.seeder.dependencies_available", return_value=True)
    @patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
    @patch("app.assets.seeder.sync_root_safely", return_value=set())
    @patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
    @patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
    def test_no_drain_when_no_pending(self, *_mocks):
        """start_enrich should not be called when there is no pending request."""
        seeder = _AssetSeeder()
        assert seeder._pending_enrich is None

        with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
            seeder.start_fast(roots=("models",))
            seeder.wait(timeout=5)

            mock_start.assert_not_called()


# ---------------------------------------------------------------------------
# Thread-safety of enqueue_enrich
# ---------------------------------------------------------------------------


class TestEnqueueEnrichThreadSafety:
    def test_concurrent_enqueues(self, seeder):
        """Multiple threads enqueuing should not lose roots."""
        with patch.object(seeder, "start_enrich", return_value=False):
            barrier = threading.Barrier(3)

            def enqueue(root):
                barrier.wait()
                seeder.enqueue_enrich(roots=(root,), compute_hashes=False)

            threads = [
                threading.Thread(target=enqueue, args=(r,))
                for r in ("models", "input", "output")
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join(timeout=5)

        merged = set(seeder._pending_enrich["roots"])
        assert merged == {"models", "input", "output"}
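

Reviewer note: the tests in this file pin down three behaviours of the pending queue: roots from successive enqueues are unioned without duplicates, `compute_hashes` is sticky once any caller asks for it, and a drain hands off and clears the pending request in one step. Below is a minimal sketch of those semantics for reference. It is not the code in `app/assets/seeder.py`; the class `PendingEnrichQueue`, its method names, and the `compute_hashes=False` default are hypothetical stand-ins chosen for illustration.

```python
import threading
from typing import Optional, TypedDict


class PendingEnrich(TypedDict):
    roots: tuple[str, ...]
    compute_hashes: bool


class PendingEnrichQueue:
    """Hypothetical stand-in for the seeder's pending-enrich bookkeeping."""

    def __init__(self) -> None:
        # Re-entrant lock, mirroring the commit message's switch from Lock to
        # RLock so code holding the lock can call back into locked methods.
        self._lock = threading.RLock()
        self._pending: Optional[PendingEnrich] = None

    def enqueue(self, roots: tuple[str, ...], compute_hashes: bool = False) -> None:
        """Store a request, merging it into any request already pending."""
        with self._lock:
            if self._pending is None:
                self._pending = {
                    "roots": tuple(roots),
                    "compute_hashes": compute_hashes,
                }
                return
            # Union the roots, keeping first-seen order and dropping duplicates.
            merged = list(self._pending["roots"])
            for root in roots:
                if root not in merged:
                    merged.append(root)
            self._pending = {
                "roots": tuple(merged),
                # Sticky: once any caller wants hashes, the merged request does.
                "compute_hashes": self._pending["compute_hashes"] or compute_hashes,
            }

    def drain(self) -> Optional[PendingEnrich]:
        """Return the pending request (if any) and clear it under the lock."""
        with self._lock:
            pending, self._pending = self._pending, None
            return pending


queue = PendingEnrichQueue()
queue.enqueue(("models", "input"), compute_hashes=False)
queue.enqueue(("input", "output"), compute_hashes=True)
drained = queue.drain()
```

Because the read and the clear in `drain` happen inside one lock-held block, an enqueue from another thread lands either before the drain (and is returned by it) or after it (and waits for the next drain), which is the property `test_concurrent_enqueues` and the drain tests rely on.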