Reduce duplication across assets module

- Extract validate_blake3_hash() into helpers.py, used by upload, schemas, routes
- Extract get_reference_with_owner_check() into queries, used by 4 service functions
- Extract build_prefix_like_conditions() into queries/common.py, used by 3 queries
- Replace 3 inlined tag queries with get_reference_tags() calls
- Consolidate AddTagsDict/RemoveTagsDict TypedDicts into AddTagsResult/RemoveTagsResult
  dataclasses, eliminating manual field copying in tagging.py
- Make iter_row_chunks delegate to iter_chunks
- Inline trivial compute_filename_for_reference wrapper (unused session param)
- Remove mark_assets_missing_outside_prefixes pass-through in bulk_ingest.py
- Clean up unused imports (os, time, dependencies_available)
- Disable assets routes on DB init failure in main.py

Amp-Thread-ID: https://ampcode.com/threads/T-019cb649-dd4e-71ff-9a0e-ae517365207b
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr
2026-03-03 17:23:32 -08:00
parent e59fbc101d
commit bfdb78da05
18 changed files with 164 additions and 230 deletions

View File

@@ -24,6 +24,7 @@ from app.assets.database.queries.asset_reference import (
get_or_create_reference,
get_reference_by_file_path,
get_reference_by_id,
get_reference_with_owner_check,
get_reference_ids_by_ids,
get_references_by_paths_and_asset_ids,
get_references_for_prefixes,
@@ -44,9 +45,9 @@ from app.assets.database.queries.asset_reference import (
upsert_reference,
)
from app.assets.database.queries.tags import (
AddTagsDict,
RemoveTagsDict,
SetTagsDict,
AddTagsResult,
RemoveTagsResult,
SetTagsResult,
add_missing_tag_for_asset_id,
add_tags_to_reference,
bulk_insert_tags_and_meta,
@@ -60,10 +61,10 @@ from app.assets.database.queries.tags import (
)
__all__ = [
"AddTagsDict",
"AddTagsResult",
"CacheStateRow",
"RemoveTagsDict",
"SetTagsDict",
"RemoveTagsResult",
"SetTagsResult",
"UnenrichedReferenceRow",
"add_missing_tag_for_asset_id",
"add_tags_to_reference",
@@ -87,6 +88,7 @@ __all__ = [
"get_or_create_reference",
"get_reference_by_file_path",
"get_reference_by_id",
"get_reference_with_owner_check",
"get_reference_ids_by_ids",
"get_reference_tags",
"get_references_by_paths_and_asset_ids",

View File

@@ -4,7 +4,6 @@ This module replaces the separate asset_info.py and cache_state.py query modules
providing a unified interface for the merged asset_references table.
"""
import os
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
@@ -25,6 +24,7 @@ from app.assets.database.models import (
)
from app.assets.database.queries.common import (
MAX_BIND_PARAMS,
build_prefix_like_conditions,
build_visible_owner_clause,
calculate_rows_per_statement,
iter_chunks,
@@ -165,6 +165,25 @@ def get_reference_by_id(
return session.get(AssetReference, reference_id)
def get_reference_with_owner_check(
session: Session,
reference_id: str,
owner_id: str,
) -> AssetReference:
"""Fetch a reference and verify ownership.
Raises:
ValueError: if reference not found
PermissionError: if owner_id doesn't match
"""
ref = get_reference_by_id(session, reference_id=reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
if ref.owner_id and ref.owner_id != owner_id:
raise PermissionError("not owner")
return ref
def get_reference_by_file_path(
session: Session,
file_path: str,
@@ -636,12 +655,8 @@ def mark_references_missing_outside_prefixes(
if not valid_prefixes:
return 0
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_sql_like_string(base)
return AssetReference.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sa.or_(*[make_prefix_condition(p) for p in valid_prefixes])
conds = build_prefix_like_conditions(valid_prefixes)
matches_valid_prefix = sa.or_(*conds)
result = session.execute(
sa.update(AssetReference)
.where(AssetReference.file_path.isnot(None))
@@ -729,13 +744,7 @@ def get_references_for_prefixes(
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_sql_like_string(base)
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
conds = build_prefix_like_conditions(prefixes)
query = (
sa.select(
@@ -875,13 +884,7 @@ def get_unenriched_references(
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_sql_like_string(base)
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
conds = build_prefix_like_conditions(prefixes)
query = (
sa.select(

View File

@@ -1,10 +1,12 @@
"""Shared utilities for database query modules."""
import os
from typing import Iterable
import sqlalchemy as sa
from app.assets.database.models import AssetReference
from app.assets.helpers import escape_sql_like_string
MAX_BIND_PARAMS = 800
@@ -24,9 +26,7 @@ def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]
"""Yield chunks of rows sized to fit within bind param limits."""
if not rows:
return
rows_per_stmt = calculate_rows_per_statement(cols_per_row)
for i in range(0, len(rows), rows_per_stmt):
yield rows[i : i + rows_per_stmt]
yield from iter_chunks(rows, calculate_rows_per_statement(cols_per_row))
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
@@ -38,3 +38,17 @@ def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
if owner_id == "":
return AssetReference.owner_id == ""
return AssetReference.owner_id.in_(["", owner_id])
def build_prefix_like_conditions(
prefixes: list[str],
) -> list[sa.sql.ColumnElement]:
"""Build LIKE conditions for matching file paths under directory prefixes."""
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_sql_like_string(base)
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
return conds

View File

@@ -1,4 +1,5 @@
from typing import Iterable, Sequence, TypedDict
from dataclasses import dataclass
from typing import Iterable, Sequence
import sqlalchemy as sa
from sqlalchemy import delete, func, select
@@ -19,19 +20,22 @@ from app.assets.database.queries.common import (
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
class AddTagsDict(TypedDict):
@dataclass(frozen=True)
class AddTagsResult:
added: list[str]
already_present: list[str]
total_tags: list[str]
class RemoveTagsDict(TypedDict):
@dataclass(frozen=True)
class RemoveTagsResult:
removed: list[str]
not_present: list[str]
total_tags: list[str]
class SetTagsDict(TypedDict):
@dataclass(frozen=True)
class SetTagsResult:
added: list[str]
removed: list[str]
total: list[str]
@@ -81,19 +85,10 @@ def set_reference_tags(
reference_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> SetTagsDict:
) -> SetTagsResult:
desired = normalize_tags(tags)
current = set(
tag_name
for (tag_name,) in (
session.execute(
select(AssetReferenceTag.tag_name).where(
AssetReferenceTag.asset_reference_id == reference_id
)
)
).all()
)
current = set(get_reference_tags(session, reference_id))
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
@@ -122,7 +117,7 @@ def set_reference_tags(
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
return SetTagsResult(added=to_add, removed=to_remove, total=desired)
def add_tags_to_reference(
@@ -132,7 +127,7 @@ def add_tags_to_reference(
origin: str = "manual",
create_if_missing: bool = True,
reference_row: AssetReference | None = None,
) -> AddTagsDict:
) -> AddTagsResult:
if not reference_row:
ref = session.get(AssetReference, reference_id)
if not ref:
@@ -141,21 +136,12 @@ def add_tags_to_reference(
norm = normalize_tags(tags)
if not norm:
total = get_reference_tags(session, reference_id=reference_id)
return {"added": [], "already_present": [], "total_tags": total}
return AddTagsResult(added=[], already_present=[], total_tags=total)
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetReferenceTag.tag_name).where(
AssetReferenceTag.asset_reference_id == reference_id
)
)
).all()
}
current = set(get_reference_tags(session, reference_id))
want = set(norm)
to_add = sorted(want - current)
@@ -179,18 +165,18 @@ def add_tags_to_reference(
nested.rollback()
after = set(get_reference_tags(session, reference_id=reference_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
return AddTagsResult(
added=sorted(((after - current) & want)),
already_present=sorted(want & current),
total_tags=sorted(after),
)
def remove_tags_from_reference(
session: Session,
reference_id: str,
tags: Sequence[str],
) -> RemoveTagsDict:
) -> RemoveTagsResult:
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
@@ -198,18 +184,9 @@ def remove_tags_from_reference(
norm = normalize_tags(tags)
if not norm:
total = get_reference_tags(session, reference_id=reference_id)
return {"removed": [], "not_present": [], "total_tags": total}
return RemoveTagsResult(removed=[], not_present=[], total_tags=total)
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetReferenceTag.tag_name).where(
AssetReferenceTag.asset_reference_id == reference_id
)
)
).all()
}
existing = set(get_reference_tags(session, reference_id))
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
@@ -224,7 +201,7 @@ def remove_tags_from_reference(
session.flush()
total = get_reference_tags(session, reference_id=reference_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
return RemoveTagsResult(removed=to_remove, not_present=not_present, total_tags=total)
def add_missing_tag_for_asset_id(