fix: address code review feedback

- Fix missing import for compute_filename_for_reference in ingest.py
- Apply code review fixes across routes, queries, scanner, seeder,
  hashing, ingest, path_utils, main, and server
- Update and add tests for sync references and seeder

Amp-Thread-ID: https://ampcode.com/threads/T-019cb61a-ed54-738c-a05f-9b5242e513f3
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr
2026-03-03 15:51:35 -08:00
parent 3232f48a41
commit 4d4c2cedd3
13 changed files with 675 additions and 218 deletions

View File

@@ -1,4 +1,5 @@
import asyncio
import functools
import json
import logging
import os
@@ -39,6 +40,20 @@ from app.assets.services import (
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
_ASSETS_ENABLED = False
def _require_assets_feature_enabled(handler):
@functools.wraps(handler)
async def wrapper(request: web.Request) -> web.Response:
if not _ASSETS_ENABLED:
return _build_error_response(
503,
"SERVICE_DISABLED",
"Assets system is disabled. Start the server with --enable-assets to use this feature.",
)
return await handler(request)
return wrapper
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
@@ -64,11 +79,13 @@ def get_query_dict(request: web.Request) -> dict[str, Any]:
# do not rely on the code in /app/assets remaining the same.
def register_assets_system(
app: web.Application, user_manager_instance: user_manager.UserManager
def register_assets_routes(
app: web.Application, user_manager_instance: user_manager.UserManager | None = None,
) -> None:
global USER_MANAGER
USER_MANAGER = user_manager_instance
global USER_MANAGER, _ASSETS_ENABLED
if user_manager_instance is not None:
USER_MANAGER = user_manager_instance
_ASSETS_ENABLED = True
app.add_routes(ROUTES)
@@ -96,6 +113,7 @@ def _validate_sort_field(requested: str | None) -> str:
@ROUTES.head("/api/assets/hash/{hash}")
@_require_assets_feature_enabled
async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
@@ -116,6 +134,7 @@ async def head_asset_by_hash(request: web.Request) -> web.Response:
@ROUTES.get("/api/assets")
@_require_assets_feature_enabled
async def list_assets_route(request: web.Request) -> web.Response:
"""
GET request to list assets.
@@ -166,6 +185,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
@_require_assets_feature_enabled
async def get_asset_route(request: web.Request) -> web.Response:
"""
GET request to get an asset's info as JSON.
@@ -211,6 +231,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
@_require_assets_feature_enabled
async def download_asset_content(request: web.Request) -> web.Response:
disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}:
@@ -264,6 +285,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/from-hash")
@_require_assets_feature_enabled
async def create_asset_from_hash_route(request: web.Request) -> web.Response:
try:
payload = await request.json()
@@ -304,6 +326,7 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets")
@_require_assets_feature_enabled
async def upload_asset(request: web.Request) -> web.Response:
"""Multipart/form-data endpoint for Asset uploads."""
try:
@@ -408,6 +431,7 @@ async def upload_asset(request: web.Request) -> web.Response:
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
@_require_assets_feature_enabled
async def update_asset_route(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"]))
try:
@@ -453,6 +477,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
@_require_assets_feature_enabled
async def delete_asset_route(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"]))
delete_content_param = request.query.get("delete_content")
@@ -484,6 +509,7 @@ async def delete_asset_route(request: web.Request) -> web.Response:
@ROUTES.get("/api/tags")
@_require_assets_feature_enabled
async def get_tags(request: web.Request) -> web.Response:
"""
GET request to list all tags based on query parameters.
@@ -520,6 +546,7 @@ async def get_tags(request: web.Request) -> web.Response:
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
@_require_assets_feature_enabled
async def add_asset_tags(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"]))
try:
@@ -569,6 +596,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
@_require_assets_feature_enabled
async def delete_asset_tags(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"]))
try:
@@ -613,6 +641,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/seed")
@_require_assets_feature_enabled
async def seed_assets(request: web.Request) -> web.Response:
"""Trigger asset seeding for specified roots (models, input, output).
@@ -662,6 +691,7 @@ async def seed_assets(request: web.Request) -> web.Response:
@ROUTES.get("/api/assets/seed/status")
@_require_assets_feature_enabled
async def get_seed_status(request: web.Request) -> web.Response:
"""Get current scan status and progress."""
status = asset_seeder.get_status()
@@ -683,6 +713,7 @@ async def get_seed_status(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/seed/cancel")
@_require_assets_feature_enabled
async def cancel_seed(request: web.Request) -> web.Response:
"""Request cancellation of in-progress scan."""
cancelled = asset_seeder.cancel()
@@ -692,6 +723,7 @@ async def cancel_seed(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/prune")
@_require_assets_feature_enabled
async def mark_missing_assets(request: web.Request) -> web.Response:
"""Mark assets as missing when outside all known root prefixes.

View File

@@ -57,6 +57,7 @@ from app.assets.database.queries.tags import (
remove_missing_tag_for_asset_id,
remove_tags_from_reference,
set_reference_tags,
validate_tags_exist,
)
__all__ = [
@@ -114,4 +115,5 @@ __all__ = [
"update_reference_updated_at",
"upsert_asset",
"upsert_reference",
"validate_tags_exist",
]

View File

@@ -660,13 +660,16 @@ def restore_references_by_paths(session: Session, file_paths: list[str]) -> int:
if not file_paths:
return 0
result = session.execute(
sa.update(AssetReference)
.where(AssetReference.file_path.in_(file_paths))
.where(AssetReference.is_missing == True) # noqa: E712
.values(is_missing=False)
)
return result.rowcount
total = 0
for chunk in iter_chunks(file_paths, MAX_BIND_PARAMS):
result = session.execute(
sa.update(AssetReference)
.where(AssetReference.file_path.in_(chunk))
.where(AssetReference.is_missing == True) # noqa: E712
.values(is_missing=False)
)
total += result.rowcount
return total
def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]:
@@ -697,11 +700,14 @@ def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int:
"""
if not asset_ids:
return 0
session.execute(
sa.delete(AssetReference).where(AssetReference.asset_id.in_(asset_ids))
)
result = session.execute(sa.delete(Asset).where(Asset.id.in_(asset_ids)))
return result.rowcount
total = 0
for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS):
session.execute(
sa.delete(AssetReference).where(AssetReference.asset_id.in_(chunk))
)
result = session.execute(sa.delete(Asset).where(Asset.id.in_(chunk)))
total += result.rowcount
return total
def get_references_for_prefixes(

View File

@@ -37,6 +37,17 @@ class SetTagsDict(TypedDict):
total: list[str]
def validate_tags_exist(session: Session, tags: list[str]) -> None:
"""Raise ValueError if any of the given tag names do not exist."""
existing_tag_names = set(
name
for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
)
missing = [t for t in tags if t not in existing_tag_names]
if missing:
raise ValueError(f"Unknown tags: {missing}")
def ensure_tags_exist(
session: Session, names: Iterable[str], tag_type: str = "user"
) -> None:

View File

@@ -44,9 +44,9 @@ from app.database.db import create_session, dependencies_available
class _RefInfo(TypedDict):
ref_id: str
fp: str
file_path: str
exists: bool
fast_ok: bool
stat_unchanged: bool
needs_verify: bool
@@ -75,9 +75,7 @@ def get_prefixes_for_root(root: RootType) -> list[str]:
def get_all_known_prefixes() -> list[str]:
"""Get all known asset prefixes across all root types."""
all_roots: tuple[RootType, ...] = ("models", "input", "output")
return [
os.path.abspath(p) for root in all_roots for p in get_prefixes_for_root(root)
]
return [p for root in all_roots for p in get_prefixes_for_root(root)]
def collect_models_files() -> list[str]:
@@ -110,10 +108,10 @@ def sync_references_with_filesystem(
) -> set[str] | None:
"""Reconcile asset references with filesystem for a root.
- Toggle needs_verify per reference using fast mtime/size check
- For hashed assets with at least one fast-ok ref: delete stale missing refs
- Toggle needs_verify per reference using mtime/size stat check
- For hashed assets with at least one stat-unchanged ref: delete stale missing refs
- For seed assets with all refs missing: delete Asset and its references
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally add/remove 'missing' tags based on stat check in this root
- Optionally return surviving absolute paths
Args:
@@ -140,10 +138,10 @@ def sync_references_with_filesystem(
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []}
by_asset[row.asset_id] = acc
fast_ok = False
stat_unchanged = False
try:
exists = True
fast_ok = verify_file_unchanged(
stat_unchanged = verify_file_unchanged(
mtime_db=row.mtime_ns,
size_db=acc["size_db"],
stat_result=os.stat(row.file_path, follow_symlinks=True),
@@ -160,9 +158,9 @@ def sync_references_with_filesystem(
acc["refs"].append(
{
"ref_id": row.reference_id,
"fp": row.file_path,
"file_path": row.file_path,
"exists": exists,
"fast_ok": fast_ok,
"stat_unchanged": stat_unchanged,
"needs_verify": row.needs_verify,
}
)
@@ -177,18 +175,18 @@ def sync_references_with_filesystem(
for aid, acc in by_asset.items():
a_hash = acc["hash"]
refs = acc["refs"]
any_fast_ok = any(r["fast_ok"] for r in refs)
any_unchanged = any(r["stat_unchanged"] for r in refs)
all_missing = all(not r["exists"] for r in refs)
for r in refs:
if not r["exists"]:
to_mark_missing.append(r["ref_id"])
continue
if r["fast_ok"]:
if r["stat_unchanged"]:
to_clear_missing.append(r["ref_id"])
if r["needs_verify"]:
to_clear_verify.append(r["ref_id"])
if not r["fast_ok"] and not r["needs_verify"]:
if not r["stat_unchanged"] and not r["needs_verify"]:
to_set_verify.append(r["ref_id"])
if a_hash is None:
@@ -197,10 +195,10 @@ def sync_references_with_filesystem(
else:
for r in refs:
if r["exists"]:
survivors.add(os.path.abspath(r["fp"]))
survivors.add(os.path.abspath(r["file_path"]))
continue
if any_fast_ok:
if any_unchanged:
for r in refs:
if not r["exists"]:
stale_ref_ids.append(r["ref_id"])
@@ -219,7 +217,7 @@ def sync_references_with_filesystem(
for r in refs:
if r["exists"]:
survivors.add(os.path.abspath(r["fp"]))
survivors.add(os.path.abspath(r["file_path"]))
delete_references_by_ids(session, stale_ref_ids)
stale_set = set(stale_ref_ids)
@@ -349,58 +347,6 @@ def build_asset_specs(
return specs, tag_pool, skipped
def build_stub_specs(
paths: list[str],
existing_paths: set[str],
) -> tuple[list[SeedAssetSpec], set[str], int]:
"""Build minimal stub specs for fast phase scanning.
Only collects filesystem metadata (stat), no file content reading.
This is the fastest possible scan to populate the asset database.
Args:
paths: List of file paths to process
existing_paths: Set of paths that already exist in the database
Returns:
Tuple of (specs, tag_pool, skipped_count)
"""
specs: list[SeedAssetSpec] = []
tag_pool: set[str] = set()
skipped = 0
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped += 1
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=True)
except OSError:
continue
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
rel_fname = compute_relative_filename(abs_p)
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": get_mtime_ns(stat_p),
"info_name": name,
"tags": tags,
"fname": rel_fname,
"metadata": None,
"hash": None,
"mime_type": None,
}
)
tag_pool.update(tags)
return specs, tag_pool, skipped
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
"""Insert asset specs into database, returning count of created refs."""
@@ -538,7 +484,8 @@ def enrich_asset(
try:
digest = compute_blake3_hash(file_path)
full_hash = f"blake3:{digest}"
if not extract_metadata or metadata:
metadata_ok = not extract_metadata or metadata is not None
if metadata_ok:
new_level = ENRICHMENT_HASHED
except Exception as e:
logging.warning("Failed to hash %s: %s", file_path, e)

View File

@@ -12,7 +12,7 @@ from app.assets.scanner import (
ENRICHMENT_METADATA,
ENRICHMENT_STUB,
RootType,
build_stub_specs,
build_asset_specs,
collect_paths_for_roots,
enrich_assets_batch,
get_all_known_prefixes,
@@ -68,35 +68,23 @@ class ScanStatus:
ProgressCallback = Callable[[Progress], None]
class AssetSeeder:
"""Singleton class managing background asset scanning.
class _AssetSeeder:
"""Background asset scanning manager.
Thread-safe singleton that spawns ephemeral daemon threads for scanning.
Spawns ephemeral daemon threads for scanning.
Each scan creates a new thread that exits when complete.
Use the module-level ``asset_seeder`` instance.
"""
_instance: "AssetSeeder | None" = None
_instance_lock = threading.Lock()
def __new__(cls) -> "AssetSeeder":
with cls._instance_lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self) -> None:
if self._initialized:
return
self._initialized = True
self._lock = threading.Lock()
self._state = State.IDLE
self._progress: Progress | None = None
self._errors: list[str] = []
self._thread: threading.Thread | None = None
self._cancel_event = threading.Event()
self._pause_event = threading.Event()
self._pause_event.set() # Start unpaused (set = running, clear = paused)
self._run_gate = threading.Event()
self._run_gate.set() # Start unpaused (set = running, clear = paused)
self._roots: tuple[RootType, ...] = ()
self._phase: ScanPhase = ScanPhase.FULL
self._compute_hashes: bool = False
@@ -154,10 +142,10 @@ class AssetSeeder:
self._compute_hashes = compute_hashes
self._progress_callback = progress_callback
self._cancel_event.clear()
self._pause_event.set() # Ensure unpaused when starting
self._run_gate.set() # Ensure unpaused when starting
self._thread = threading.Thread(
target=self._run_scan,
name="AssetSeeder",
name="_AssetSeeder",
daemon=True,
)
self._thread.start()
@@ -223,7 +211,7 @@ class AssetSeeder:
logging.info("Asset seeder cancelling (was %s)", self._state.value)
self._state = State.CANCELLING
self._cancel_event.set()
self._pause_event.set() # Unblock if paused so thread can exit
self._run_gate.set() # Unblock if paused so thread can exit
return True
def stop(self) -> bool:
@@ -247,7 +235,7 @@ class AssetSeeder:
return False
logging.info("Asset seeder pausing")
self._state = State.PAUSED
self._pause_event.clear()
self._run_gate.clear()
return True
def resume(self) -> bool:
@@ -263,7 +251,7 @@ class AssetSeeder:
return False
logging.info("Asset seeder resuming")
self._state = State.RUNNING
self._pause_event.set()
self._run_gate.set()
self._emit_event("assets.seed.resumed", {})
return True
@@ -356,10 +344,10 @@ class AssetSeeder:
self._thread = None
def mark_missing_outside_prefixes(self) -> int:
"""Mark cache states as missing when outside all known root prefixes.
"""Mark references as missing when outside all known root prefixes.
This is a non-destructive soft-delete operation. Assets and their
metadata are preserved, but cache states are flagged as missing.
metadata are preserved, but references are flagged as missing.
They can be restored if the file reappears in a future scan.
This operation is decoupled from scanning to prevent partial scans
@@ -369,7 +357,7 @@ class AssetSeeder:
a full scan of all roots or during maintenance.
Returns:
Number of cache states marked as missing
Number of references marked as missing
Raises:
ScanInProgressError: If a scan is currently running
@@ -389,7 +377,7 @@ class AssetSeeder:
all_prefixes = get_all_known_prefixes()
marked = mark_missing_outside_prefixes_safely(all_prefixes)
if marked > 0:
logging.info("Marked %d cache states as missing", marked)
logging.info("Marked %d references as missing", marked)
return marked
finally:
with self._lock:
@@ -409,9 +397,9 @@ class AssetSeeder:
Returns:
True if scan should stop, False to continue
"""
if not self._pause_event.is_set():
if not self._run_gate.is_set():
self._emit_event("assets.seed.paused", {})
self._pause_event.wait() # Blocks if paused
self._run_gate.wait() # Blocks if paused
return self._is_cancelled()
def _emit_event(self, event_type: str, data: dict) -> None:
@@ -539,7 +527,11 @@ class AssetSeeder:
cancelled = True
return
total_enriched = self._run_enrich_phase(roots)
enrich_cancelled, total_enriched = self._run_enrich_phase(roots)
if enrich_cancelled:
cancelled = True
return
self._emit_event(
"assets.seed.enrich_complete",
@@ -613,7 +605,9 @@ class AssetSeeder:
)
# Use stub specs (no metadata extraction, no hashing)
specs, tag_pool, skipped_existing = build_stub_specs(paths, existing_paths)
specs, tag_pool, skipped_existing = build_asset_specs(
paths, existing_paths, enable_metadata_extraction=False, compute_hashes=False,
)
self._update_progress(skipped=skipped_existing)
if self._check_pause_and_cancel():
@@ -661,11 +655,11 @@ class AssetSeeder:
self._update_progress(scanned=len(specs), created=total_created)
return total_created, skipped_existing, total_paths
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> int:
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]:
"""Run phase 2: enrich existing records with metadata and hashes.
Returns:
Total number of assets enriched
Tuple of (cancelled, total_enriched)
"""
total_enriched = 0
batch_size = 100
@@ -690,7 +684,7 @@ class AssetSeeder:
while True:
if self._check_pause_and_cancel():
logging.info("Enrich scan cancelled after %d assets", total_enriched)
break
return True, total_enriched
# Fetch next batch of unenriched assets
unenriched = get_unenriched_assets_for_roots(
@@ -737,7 +731,7 @@ class AssetSeeder:
)
last_progress_time = now
return total_enriched
return False, total_enriched
asset_seeder = AssetSeeder()
asset_seeder = _AssetSeeder()

View File

@@ -1,4 +1,3 @@
import asyncio
import os
from typing import IO
@@ -18,20 +17,6 @@ def compute_blake3_hash(
return _hash_file_obj(f, chunk_size)
async def compute_blake3_hash_async(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
if hasattr(fp, "read"):
return await asyncio.to_thread(compute_blake3_hash, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
return await asyncio.to_thread(_worker)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK

View File

@@ -2,17 +2,16 @@ import contextlib
import logging
import mimetypes
import os
from typing import Sequence
from typing import Any, Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
import app.assets.services.hashing as hashing
from app.assets.database.models import Asset, AssetReference, Tag
from app.assets.database.queries import (
add_tags_to_reference,
fetch_reference_and_asset,
get_asset_by_hash,
get_existing_asset_ids,
get_reference_by_file_path,
get_reference_tags,
get_or_create_reference,
@@ -21,11 +20,13 @@ from app.assets.database.queries import (
set_reference_tags,
upsert_asset,
upsert_reference,
validate_tags_exist,
)
from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import (
compute_filename_for_reference,
compute_relative_filename,
resolve_destination_from_tags,
validate_path_within_base,
)
@@ -55,6 +56,7 @@ def _ingest_file_from_path(
require_existing_tags: bool = False,
) -> IngestResult:
locator = os.path.abspath(abs_path)
user_metadata = user_metadata or {}
asset_created = False
asset_updated = False
@@ -64,7 +66,7 @@ def _ingest_file_from_path(
with create_session() as session:
if preview_id:
if not session.get(Asset, preview_id):
if preview_id not in get_existing_asset_ids(session, [preview_id]):
preview_id = None
asset, asset_created, asset_updated = upsert_asset(
@@ -94,7 +96,7 @@ def _ingest_file_from_path(
norm = normalize_tags(list(tags))
if norm:
if require_existing_tags:
_validate_tags_exist(session, norm)
validate_tags_exist(session, norm)
add_tags_to_reference(
session,
reference_id=reference_id,
@@ -106,7 +108,8 @@ def _ingest_file_from_path(
_update_metadata_with_filename(
session,
reference_id=reference_id,
ref=ref,
file_path=ref.file_path,
current_metadata=ref.user_metadata,
user_metadata=user_metadata,
)
@@ -134,6 +137,8 @@ def _register_existing_asset(
tag_origin: str = "manual",
owner_id: str = "",
) -> RegisterAssetResult:
user_metadata = user_metadata or {}
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
@@ -157,7 +162,7 @@ def _register_existing_asset(
session.commit()
return result
new_meta = dict(user_metadata or {})
new_meta = dict(user_metadata)
computed_filename = compute_filename_for_reference(session, ref)
if computed_filename:
new_meta["filename"] = computed_filename
@@ -190,29 +195,20 @@ def _register_existing_asset(
return result
def _validate_tags_exist(session: Session, tags: list[str]) -> None:
existing_tag_names = set(
name
for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
)
missing = [t for t in tags if t not in existing_tag_names]
if missing:
raise ValueError(f"Unknown tags: {missing}")
def _update_metadata_with_filename(
session: Session,
reference_id: str,
ref: AssetReference,
user_metadata: UserMetadata,
file_path: str | None,
current_metadata: dict | None,
user_metadata: dict[str, Any],
) -> None:
computed_filename = compute_filename_for_reference(session, ref)
computed_filename = compute_relative_filename(file_path) if file_path else None
current_meta = ref.user_metadata or {}
current_meta = current_metadata or {}
new_meta = dict(current_meta)
if user_metadata:
for k, v in user_metadata.items():
new_meta[k] = v
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename

View File

@@ -51,8 +51,9 @@ def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
raw_subdirs = tags[1:]
else:
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
_sep_chars = frozenset(("/", "\\", os.sep))
for i in raw_subdirs:
if i in (".", ".."):
if i in (".", "..") or _sep_chars & set(i):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
@@ -113,6 +114,8 @@ def get_asset_category_and_relative_path(
return Path(child).is_relative_to(parent)
def _compute_relative(child: str, parent: str) -> str:
# Normalize relative path, stripping any leading ".." components
# by anchoring to root (os.sep) then computing relpath back from it.
return os.path.relpath(
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
)