Compare commits

...

35 Commits

Author SHA1 Message Date
Luke Mino-Altherr
fa3749ced7 Add TypedDict types to scanner and bulk_ingest
Amp-Thread-ID: https://ampcode.com/threads/T-019c2af9-4d41-73e9-b38d-78d06bc28a3f
Co-authored-by: Amp <amp@ampcode.com>
2026-02-04 15:32:26 -08:00
Luke Mino-Altherr
16b5d9112b Fix path traversal validation to return 400 instead of 500
Catch ValueError from resolve_destination_from_tags in the upload
endpoint so that invalid path components like '..' return a 400
BAD_REQUEST error instead of falling through to the 500 handler.

Amp-Thread-ID: https://ampcode.com/threads/T-019c2af2-7c87-7263-88b0-9feca1c31b3c
Co-authored-by: Amp <amp@ampcode.com>
2026-02-04 15:24:51 -08:00
Luke Mino-Altherr
abeec3072b refactor(assets): extract scanner logic into service modules
- Create file_utils.py with shared file utilities:
  - get_mtime_ns() - extract mtime in nanoseconds from stat
  - get_size_and_mtime_ns() - get both size and mtime
  - verify_file_unchanged() - check file matches DB mtime/size
  - list_files_recursively() - recursive directory listing

- Create bulk_ingest.py for bulk operations:
  - BulkInsertResult dataclass
  - batch_insert_seed_assets() - batch insert with conflict handling
  - prune_orphaned_assets() - clean up orphaned assets

- Update scanner.py to use new service modules instead of
  calling database queries directly

- Update ingest.py to use shared get_size_and_mtime_ns()

- Export new functions from services/__init__.py

Amp-Thread-ID: https://ampcode.com/threads/T-019c2ae7-f701-716a-a0dd-1feb988732fb
Co-authored-by: Amp <amp@ampcode.com>
2026-02-04 15:17:31 -08:00
Luke Mino-Altherr
b23302f372 refactor(assets): consolidate duplicated query utilities and remove unused code
- Extract shared helpers to database/queries/common.py:
  - MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks, iter_row_chunks
  - build_visible_owner_clause

- Remove duplicate _compute_filename_for_asset, consolidate in path_utils.py

- Remove unused get_asset_info_with_tags (duplicated get_asset_detail)

- Remove redundant __all__ from cache_state.py

- Make internal helpers private (_check_is_scalar)

Amp-Thread-ID: https://ampcode.com/threads/T-019c2ad9-9432-7451-94a8-79287dbbb19e
Co-authored-by: Amp <amp@ampcode.com>
2026-02-04 15:04:30 -08:00
Luke Mino-Altherr
adf6eb73fd refactor: eliminate manager layer, routes call services directly
- Delete app/assets/manager.py
- Move upload logic (upload_from_temp_path, create_from_hash) to ingest service
- Add HashMismatchError and DependencyMissingError to ingest service
- Add UploadResult schema for upload responses
- Update routes.py to import services directly and do schema conversion inline
- Add asset lookup/listing service functions to asset_management.py

Routes now call the service layer directly, removing an unnecessary
layer of indirection. The manager was only converting between service
dataclasses and Pydantic response schemas.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 14:50:11 -08:00
Luke Mino-Altherr
5259959fef refactor: require blake3 package directly in hashing module
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 20:42:11 -08:00
Luke Mino-Altherr
5474d8bf84 chore: consolidate service imports in manager.py
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 20:36:06 -08:00
Luke Mino-Altherr
9290e26e9f refactor: add explicit types to asset service functions
- Add typed result dataclasses: IngestResult, AddTagsResult,
  RemoveTagsResult, SetTagsResult, TagUsage
- Add UserMetadata type alias for user_metadata parameters
- Type helper functions with Session parameters
- Use TypedDicts at query layer to avoid circular imports
- Update manager.py and tests to use attribute access

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 20:32:14 -08:00
Luke Mino-Altherr
37ecc5b663 chore: remove obvious/self-documenting comments from assets package
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 20:14:11 -08:00
Luke Mino-Altherr
80d99e7b63 chore: remove module-level comments and docstrings from assets package
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 20:04:47 -08:00
Luke Mino-Altherr
d8cb122dfb chore: sort imports in assets package
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 20:02:52 -08:00
Luke Mino-Altherr
0f75def5b5 refactor: move scanner.py out of services to top-level assets module
Scanner is used externally by main.py and server.py for startup/maintenance,
not as part of the regular service layer. Moving it to app/assets/scanner.py
makes the public API clearer.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 19:56:29 -08:00
Luke Mino-Altherr
6b1f9f7755 refactor: convert asset tests to table-driven parametrized tests
- test_metadata.py: consolidate 7 filter type classes into parametrized tests
- test_asset.py: parametrize exists, get, and upsert test cases
- test_cache_state.py: parametrize upsert and delete scenarios
- test_crud.py: consolidate error response tests into single parametrized test
- test_list_filter.py: consolidate invalid query tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 19:50:16 -08:00
Luke Mino-Altherr
3311b13740 chore: remove unused re-exports from conftest.py
The helper functions are already imported directly from helpers.py
by all test files, so the backwards compatibility re-export is dead code.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 19:22:24 -08:00
Luke Mino-Altherr
bf7fbb6317 chore: remove unused get_utc_now import
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 18:47:11 -08:00
Luke Mino-Altherr
5571508e61 refactor: use query functions instead of direct ORM modifications in service layer
Add update_asset_info_name and update_asset_info_updated_at query functions
and update asset_management.py to use them instead of modifying ORM objects
directly. This ensures the service layer only uses explicit operations from
the queries package.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 18:44:23 -08:00
Luke Mino-Altherr
e3b8e512ca refactor: use explicit dataclasses instead of ORM objects in service layer
Replace dict/ORM object returns with explicit dataclasses to fix
DetachedInstanceError when accessing ORM attributes after session closes.

- Add app/assets/services/schemas.py with AssetData, AssetInfoData,
  AssetDetailResult, and RegisterAssetResult dataclasses
- Update asset_management.py and ingest.py to return dataclasses
- Update manager.py to use attribute access on dataclasses
- Fix created_new to be False in create_asset_from_hash (content exists)
- Add DependencyMissingError for better blake3 missing error handling
- Update tests to use attribute access instead of dict subscripting

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 18:39:07 -08:00
Luke Mino-Altherr
ea01cd665d fix: resolve test import errors and module collision in assets_test
Extract helper functions from conftest.py to a dedicated helpers.py module
to fix import resolution issues when pytest processes subdirectories.
Rename test_tags.py to test_tags_api.py to avoid module name collision
with queries/test_tags.py.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 14:57:52 -08:00
Luke Mino-Altherr
ccfc5dedd4 fix: handle missing blake3 module gracefully to prevent server crash
Make blake3 an optional import that fails gracefully at import time,
with a clear error message when hashing functions are actually called.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 14:46:33 -08:00
Luke Mino-Altherr
e9ca190098 refactor: remove try-finally wrapper in seed_assets by extracting helpers
Extract focused helper functions to eliminate the try-finally block that
wrapped ~50 lines just for logging. The new helpers (_collect_paths_for_roots,
_build_asset_specs, _insert_asset_specs) make seed_assets a simple linear flow.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 14:33:54 -08:00
Luke Mino-Altherr
ed60e93696 refactor: flatten nested try blocks and if statements in assets package
Extract helper functions to eliminate nested try-except blocks in scanner.py
and remove duplicated type-checking logic in asset_info.py. Simplify nested
conditionals in asset_management.py for clearer control flow.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 14:28:16 -08:00
Luke Mino-Altherr
fef2f01671 refactor: improve function naming for clarity and consistency
Rename functions to use clearer verb-based names:
- pick_best_live_path → select_best_live_path
- escape_like_prefix → escape_sql_like_string
- list_tree → list_files_recursively
- check_asset_file_fast → verify_asset_file_unchanged
- _seed_from_paths_batch → _batch_insert_assets_from_paths
- reconcile_cache_states_for_root → sync_cache_states_with_filesystem
- touch_asset_info_by_id → update_asset_info_access_time
- replace_asset_info_metadata_projection → set_asset_info_metadata
- expand_metadata_to_rows → convert_metadata_to_rows
- _rows_per_stmt → _calculate_rows_per_statement
- ensure_within_base → validate_path_within_base
- _cleanup_temp → _delete_temp_file_if_exists
- validate_hash_format → normalize_and_validate_hash
- get_relative_to_root_category_path_of_asset → get_asset_category_and_relative_path

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 14:20:36 -08:00
Luke Mino-Altherr
481a2fa263 refactor: rename functions to verb-based naming convention
Rename functions across app/assets/ to follow verb-based naming:
- is_scalar → check_is_scalar
- project_kv → expand_metadata_to_rows
- _visible_owner_clause → _build_visible_owner_clause
- _chunk_rows → _iter_row_chunks
- _at_least_one → _validate_at_least_one_field
- _tags_norm → _normalize_tags_field
- _ser_dt → _serialize_datetime
- _ser_updated → _serialize_updated_at
- _error_response → _build_error_response
- _validation_error_response → _build_validation_error_response
- file_sender → stream_file_chunks
- seed_assets_endpoint → seed_assets
- utcnow → get_utc_now
- _safe_sort_field → _validate_sort_field
- _safe_filename → _sanitize_filename
- fast_asset_file_check → check_asset_file_fast
- prefixes_for_root → get_prefixes_for_root
- blake3_hash → compute_blake3_hash
- blake3_hash_async → compute_blake3_hash_async
- _is_within → _check_is_within
- _rel → _compute_relative

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 13:58:14 -08:00
Luke Mino-Altherr
11ca1995a3 fix: remaining ruff linting errors in services tests
- Remove unused os imports in conftest.py and test_ingest.py
- Remove unused Tag import in test_asset_management.py
- Remove unused ensure_tags_exist import in test_ingest.py
- Fix unused info2 variable in test_asset_management.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 13:31:38 -08:00
Luke Mino-Altherr
4e02245012 fix: ruff linting errors and add comprehensive test coverage for asset queries
- Fix unused imports in routes.py, asset.py, manager.py, asset_management.py, ingest.py
- Fix whitespace issues in upload.py, asset_info.py, ingest.py
- Fix typo in manager.py (stray character after result["asset"])
- Fix broken import in test_metadata.py (project_kv moved to asset_info.py)
- Add fixture override in queries/conftest.py for unit test isolation

Add 48 new tests covering all previously untested query functions:
- asset.py: upsert_asset, bulk_insert_assets
- cache_state.py: upsert_cache_state, delete_cache_states_outside_prefixes,
  get_orphaned_seed_asset_ids, delete_assets_by_ids, get_cache_states_for_prefixes,
  bulk_set_needs_verify, delete_cache_states_by_ids, delete_orphaned_seed_asset,
  bulk_insert_cache_states_ignore_conflicts, get_cache_states_by_paths_and_asset_ids
- asset_info.py: insert_asset_info, get_or_create_asset_info,
  update_asset_info_timestamps, replace_asset_info_metadata_projection,
  bulk_insert_asset_infos_ignore_conflicts, get_asset_info_ids_by_ids
- tags.py: bulk_insert_tags_and_meta

Total: 119 tests pass (up from 71)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 13:21:12 -08:00
Luke Mino-Altherr
9f9db2c2c2 refactor: extract multipart upload parsing from routes
- Add app/assets/api/upload.py with parse_multipart_upload() for HTTP parsing
- Add ParsedUpload dataclass to schemas_in.py
- Add domain exceptions (AssetValidationError, AssetNotFoundError, HashMismatchError)
- Add manager.process_upload() with domain exceptions (no HTTP status codes)
- Routes map domain exceptions to HTTP responses
- Slim down upload_asset route to ~20 lines (was ~150)

Amp-Thread-ID: https://ampcode.com/threads/T-019c2519-abe1-738a-ad2e-29ece17c0e42
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
e987bd268f Move get_comfy_models_folders to path_utils.py to avoid late import
Amp-Thread-ID: https://ampcode.com/threads/T-019c2510-33fa-7199-ae4b-bc31102277a7
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
2eb100adf9 Refactor helpers.py: move functions to their respective modules
- Move scanner-only functions to scanner.py
- Move query-only functions (is_scalar, project_kv) to asset_info.py
- Move get_query_dict to routes.py
- Create path_utils.py service for path-related functions
- Reduce helpers.py to shared utilities only

Amp-Thread-ID: https://ampcode.com/threads/T-019c2510-33fa-7199-ae4b-bc31102277a7
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
a02f160e20 Move hashing.py to services directory
Amp-Thread-ID: https://ampcode.com/threads/T-019c2510-33fa-7199-ae4b-bc31102277a7
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
c3105b1174 refactor: move bulk_ops to queries and scanner service
- Delete bulk_ops.py, moving logic to appropriate layers
- Add bulk insert query functions:
  - queries/asset.bulk_insert_assets
  - queries/cache_state.bulk_insert_cache_states_ignore_conflicts
  - queries/cache_state.get_cache_states_by_paths_and_asset_ids
  - queries/asset_info.bulk_insert_asset_infos_ignore_conflicts
  - queries/asset_info.get_asset_info_ids_by_ids
  - queries/tags.bulk_insert_tags_and_meta
- Move seed_from_paths_batch orchestration to scanner._seed_from_paths_batch

Amp-Thread-ID: https://ampcode.com/threads/T-019c24fd-157d-776a-ad24-4f19cf5d3afe
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
64d2f51dfc refactor: move scanner to services layer with pure query extraction
- Move app/assets/scanner.py to app/assets/services/scanner.py
- Extract pure queries from fast_db_consistency_pass:
  - get_cache_states_for_prefixes()
  - bulk_set_needs_verify()
  - delete_cache_states_by_ids()
  - delete_orphaned_seed_asset()
- Split prune_orphaned_assets into pure queries:
  - delete_cache_states_outside_prefixes()
  - get_orphaned_seed_asset_ids()
  - delete_assets_by_ids()
- Add reconcile_cache_states_for_root() service function
- Add prune_orphaned_assets() service function
- Remove function injection pattern
- Update imports in main.py, server.py, routes.py

Amp-Thread-ID: https://ampcode.com/threads/T-019c24f1-3385-701b-87e0-8b6bc87e841b
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
fba4570e49 refactor: move in-function imports to top-level and remove keyword-only argument pattern
- Move imports from inside functions to module top-level in:
  - app/assets/database/queries/asset.py
  - app/assets/database/queries/asset_info.py
  - app/assets/database/queries/cache_state.py
  - app/assets/manager.py
  - app/assets/services/asset_management.py
  - app/assets/services/ingest.py

- Remove keyword-only argument markers (*,) from app/assets/ to match codebase conventions

Amp-Thread-ID: https://ampcode.com/threads/T-019c24eb-bfa2-727f-8212-8bc976048604
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
15ee03f65c Refactor asset database: separate business logic from queries
Architecture changes:
- API Routes -> manager.py (thin adapter) -> services/ (business logic) -> queries/ (atomic DB ops)
- Services own session lifecycle via create_session()
- Queries accept Session as parameter, do single-table atomic operations

New app/assets/services/ layer:
- __init__.py - exports all service functions
- ingest.py - ingest_file_from_path(), register_existing_asset()
- asset_management.py - get_asset_detail(), update_asset_metadata(), delete_asset_reference(), set_asset_preview()
- tagging.py - apply_tags(), remove_tags(), list_tags()

Removed from queries/asset_info.py:
- ingest_fs_asset (moved to services/ingest.py as ingest_file_from_path)
- update_asset_info_full (moved to services/asset_management.py as update_asset_metadata)
- create_asset_info_for_existing_asset (moved to services/ingest.py as register_existing_asset)

Updated manager.py:
- Now a thin adapter that transforms API schemas to/from service calls
- Delegates all business logic to services layer
- No longer imports sqlalchemy.orm.Session or models directly

Test updates:
- Fixed test_cache_state.py import of pick_best_live_path (moved to helpers.py)
- Added comprehensive service layer tests (41 new tests)
- All 112 query + service tests pass

Amp-Thread-ID: https://ampcode.com/threads/T-019c24e2-7ae4-707f-ad19-c775ed8b82b5
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
70a600baf0 chore: remove unused Asset import from manager.py
Amp-Thread-ID: https://ampcode.com/threads/T-019c24bb-475b-7442-9ff9-8288edea3345
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
17ad7e393f refactor(assets): split queries.py into modular query modules
Split the ~1000 line app/assets/database/queries.py into focused modules:

- queries/asset.py - Asset entity queries (asset_exists_by_hash, get_asset_by_hash)
- queries/asset_info.py - AssetInfo queries (~15 functions)
- queries/cache_state.py - AssetCacheState queries (list_cache_states_by_asset_id,
  pick_best_live_path, prune_orphaned_assets, fast_db_consistency_pass)
- queries/tags.py - Tag queries (8 functions including ensure_tags_exist,
  add/remove tag functions, list_tags_with_usage)
- queries/__init__.py - Re-exports all public functions for backward compatibility

Also adds comprehensive unit tests using in-memory SQLite:
- tests-unit/assets_test/queries/conftest.py - Session fixture
- tests-unit/assets_test/queries/test_asset.py - 5 tests
- tests-unit/assets_test/queries/test_asset_info.py - 23 tests
- tests-unit/assets_test/queries/test_cache_state.py - 8 tests
- tests-unit/assets_test/queries/test_metadata.py - 12 tests for _apply_metadata_filter
- tests-unit/assets_test/queries/test_tags.py - 23 tests

All 71 unit tests pass. Existing integration tests unaffected.

Amp-Thread-ID: https://ampcode.com/threads/T-019c24bb-475b-7442-9ff9-8288edea3345
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
46 changed files with 6278 additions and 2705 deletions

View File

@@ -1,19 +1,36 @@
import logging
import uuid
import urllib.parse
import os
import contextlib
from aiohttp import web
import urllib.parse
import uuid
from typing import Any
from aiohttp import web
from pydantic import ValidationError
import app.assets.manager as manager
from app import user_manager
from app.assets.api import schemas_in
from app.assets.helpers import get_query_dict
from app.assets.scanner import seed_assets
import folder_paths
from app import user_manager
from app.assets.api import schemas_in, schemas_out
from app.assets.api.schemas_in import (
AssetValidationError,
UploadError,
)
from app.assets.api.upload import parse_multipart_upload
from app.assets.scanner import seed_assets as scanner_seed_assets
from app.assets.services import (
DependencyMissingError,
HashMismatchError,
apply_tags,
asset_exists,
create_from_hash,
delete_asset_reference,
get_asset_detail,
list_assets_page,
list_tags,
remove_tags,
resolve_asset_for_download,
update_asset_metadata,
upload_from_temp_path,
)
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
@@ -21,36 +38,78 @@ USER_MANAGER: user_manager.UserManager | None = None
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
def get_query_dict(request: web.Request) -> dict[str, Any]:
"""
Gets a dictionary of query parameters from the request.
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
"""
query_dict = {
key: request.query.getall(key)
if len(request.query.getall(key)) > 1
else request.query.get(key)
for key in request.query.keys()
}
return query_dict
# Note to any custom node developers reading this code:
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
def register_assets_system(
app: web.Application, user_manager_instance: user_manager.UserManager
) -> None:
global USER_MANAGER
USER_MANAGER = user_manager_instance
app.add_routes(ROUTES)
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
def _build_error_response(
status: int, code: str, message: str, details: dict | None = None
) -> web.Response:
return web.json_response(
{"error": {"code": code, "message": message, "details": details or {}}},
status=status,
)
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
def _build_validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _build_error_response(400, code, "Validation failed.", {"errors": ve.json()})
def _validate_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
@ROUTES.head("/api/assets/hash/{hash}")
async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
return _build_error_response(
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
)
algo, digest = hash_str.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
exists = manager.asset_exists(asset_hash=hash_str)
if (
algo != "blake3"
or not digest
or any(c for c in digest if c not in "0123456789abcdef")
):
return _build_error_response(
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
)
exists = asset_exists(hash_str)
return web.Response(status=200 if exists else 404)
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
async def list_assets_route(request: web.Request) -> web.Response:
"""
GET request to list assets.
"""
@@ -58,66 +117,124 @@ async def list_assets(request: web.Request) -> web.Response:
try:
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
except ValidationError as ve:
return _validation_error_response("INVALID_QUERY", ve)
return _build_validation_error_response("INVALID_QUERY", ve)
payload = manager.list_assets(
sort = _validate_sort_field(q.sort)
order = (
"desc"
if (q.order or "desc").lower() not in {"asc", "desc"}
else q.order.lower()
)
result = list_assets_page(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=q.sort,
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
sort=sort,
order=order,
)
summaries = [
schemas_out.AssetSummary(
id=item.info.id,
name=item.info.name,
asset_hash=item.asset.hash if item.asset else None,
size=int(item.asset.size_bytes)
if item.asset and item.asset.size_bytes
else None,
mime_type=item.asset.mime_type if item.asset else None,
tags=item.tags,
created_at=item.info.created_at,
updated_at=item.info.updated_at,
last_access_time=item.info.last_access_time,
)
for item in result.items
]
payload = schemas_out.AssetsList(
assets=summaries,
total=result.total,
has_more=(q.offset + len(summaries)) < result.total,
)
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset(request: web.Request) -> web.Response:
async def get_asset_route(request: web.Request) -> web.Response:
"""
GET request to get an asset's info as JSON.
"""
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
result = manager.get_asset(
result = get_asset_detail(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
if not result:
return _build_error_response(
404,
"ASSET_NOT_FOUND",
f"AssetInfo {asset_info_id} not found",
{"id": asset_info_id},
)
payload = schemas_out.AssetDetail(
id=result.info.id,
name=result.info.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes)
if result.asset and result.asset.size_bytes is not None
else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
preview_id=result.info.preview_id,
created_at=result.info.created_at,
last_access_time=result.info.last_access_time,
)
except ValueError as e:
return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id})
return _build_error_response(
404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id}
)
except Exception:
logging.exception(
"get_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
async def download_asset_content(request: web.Request) -> web.Response:
# question: do we need disposition? could we just stick with one of these?
disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}:
disposition = "attachment"
try:
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
result = resolve_asset_for_download(
asset_info_id=str(uuid.UUID(request.match_info["id"])),
owner_id=USER_MANAGER.get_request_user_id(request),
)
abs_path = result.abs_path
content_type = result.content_type
filename = result.download_name
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
return _build_error_response(404, "ASSET_NOT_FOUND", str(ve))
except NotImplementedError as nie:
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
return _build_error_response(501, "BACKEND_UNSUPPORTED", str(nie))
except FileNotFoundError:
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
return _build_error_response(
404, "FILE_NOT_FOUND", "Underlying file not found on disk."
)
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
cd = f"{disposition}; filename=\"{quoted}\"; filename*=UTF-8''{urllib.parse.quote(filename)}"
file_size = os.path.getsize(abs_path)
logging.info(
@@ -129,7 +246,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
filename,
)
async def file_sender():
async def stream_file_chunks():
chunk_size = 64 * 1024
with open(abs_path, "rb") as f:
while True:
@@ -139,7 +256,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
yield chunk
return web.Response(
body=file_sender(),
body=stream_file_chunks(),
content_type=content_type,
headers={
"Content-Disposition": cd,
@@ -149,16 +266,18 @@ async def download_asset_content(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/from-hash")
async def create_asset_from_hash(request: web.Request) -> web.Response:
async def create_asset_from_hash_route(request: web.Request) -> web.Response:
try:
payload = await request.json()
body = schemas_in.CreateFromHashBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
return _build_validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
return _build_error_response(
400, "INVALID_JSON", "Request body must be valid JSON."
)
result = manager.create_asset_from_hash(
result = create_from_hash(
hash_str=body.hash,
name=body.name,
tags=body.tags,
@@ -166,228 +285,191 @@ async def create_asset_from_hash(request: web.Request) -> web.Response:
owner_id=USER_MANAGER.get_request_user_id(request),
)
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
return web.json_response(result.model_dump(mode="json"), status=201)
return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
)
payload_out = schemas_out.AssetCreated(
id=result.info.id,
name=result.info.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes) if result.asset.size_bytes else None,
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
preview_id=result.info.preview_id,
created_at=result.info.created_at,
last_access_time=result.info.last_access_time,
created_new=result.created_new,
)
return web.json_response(payload_out.model_dump(mode="json"), status=201)
def _delete_temp_file_if_exists(path: str | None) -> None:
if path and os.path.exists(path):
try:
os.remove(path)
except Exception:
pass
@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
"""Multipart/form-data endpoint for Asset uploads."""
if not (request.content_type or "").lower().startswith("multipart/"):
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
reader = await request.multipart()
file_present = False
file_client_name: str | None = None
tags_raw: list[str] = []
provided_name: str | None = None
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
file_written = 0
tmp_path: str | None = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if s:
if ":" not in s:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
provided_hash = f"{algo}:{digest}"
try:
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
except Exception:
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
continue # Do not create temp file; we will create AssetInfo from the existing content
# Otherwise, store to temp for hashing/ingest
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
try:
if os.path.exists(tmp_path or ""):
os.remove(tmp_path)
finally:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
# If client did not send file, and we are not doing a from-hash fast path -> error
if not file_present and not (provided_hash and provided_hash_exists):
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
# Empty upload is only acceptable if we are fast-pathing from existing hash
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
try:
spec = schemas_in.UploadAssetSpec.model_validate({
"tags": tags_raw,
"name": provided_name,
"user_metadata": user_metadata_raw,
"hash": provided_hash,
})
except ValidationError as ve:
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _validation_error_response("INVALID_BODY", ve)
# Validate models category against configured folders (consistent with previous behavior)
if spec.tags and spec.tags[0] == "models":
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
)
parsed = await parse_multipart_upload(request, check_hash_exists=asset_exists)
except UploadError as e:
return _build_error_response(e.status, e.code, e.message)
owner_id = USER_MANAGER.get_request_user_id(request)
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and provided_hash_exists is True:
try:
result = manager.create_asset_from_hash(
try:
spec = schemas_in.UploadAssetSpec.model_validate(
{
"tags": parsed.tags_raw,
"name": parsed.provided_name,
"user_metadata": parsed.user_metadata_raw,
"hash": parsed.provided_hash,
}
)
except ValidationError as ve:
_delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(
400, "INVALID_BODY", f"Validation failed: {ve.json()}"
)
if spec.tags and spec.tags[0] == "models":
if (
len(spec.tags) < 2
or spec.tags[1] not in folder_paths.folder_names_and_paths
):
_delete_temp_file_if_exists(parsed.tmp_path)
category = spec.tags[1] if len(spec.tags) >= 2 else ""
return _build_error_response(
400, "INVALID_BODY", f"unknown models category '{category}'"
)
try:
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and parsed.provided_hash_exists is True:
result = create_from_hash(
hash_str=spec.hash,
name=spec.name or (spec.hash.split(":", 1)[1]),
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
)
except Exception:
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if result is None:
_delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist"
)
_delete_temp_file_if_exists(parsed.tmp_path)
else:
# Otherwise, we must have a temp file path to ingest
if not parsed.tmp_path or not os.path.exists(parsed.tmp_path):
return _build_error_response(
404,
"ASSET_NOT_FOUND",
"Provided hash not found and no file uploaded.",
)
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
# Drain temp if we accidentally saved (e.g., hash field came after file)
if tmp_path and os.path.exists(tmp_path):
with contextlib.suppress(Exception):
os.remove(tmp_path)
status = 200 if (not result.created_new) else 201
return web.json_response(result.model_dump(mode="json"), status=status)
# Otherwise, we must have a temp file path to ingest
if not tmp_path or not os.path.exists(tmp_path):
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
try:
created = manager.upload_asset_from_temp_path(
spec,
temp_path=tmp_path,
client_filename=file_client_name,
owner_id=owner_id,
expected_asset_hash=spec.hash,
)
status = 201 if created.created_new else 200
return web.json_response(created.model_dump(mode="json"), status=status)
except ValueError as e:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
msg = str(e)
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
return _error_response(
400,
"HASH_MISMATCH",
"Uploaded file hash does not match provided hash.",
result = upload_from_temp_path(
temp_path=parsed.tmp_path,
name=spec.name,
tags=spec.tags,
user_metadata=spec.user_metadata or {},
client_filename=parsed.file_client_name,
owner_id=owner_id,
expected_hash=spec.hash,
)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except AssetValidationError as e:
_delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, e.code, str(e))
except ValueError as e:
_delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, "BAD_REQUEST", str(e))
except HashMismatchError as e:
_delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, "HASH_MISMATCH", str(e))
except DependencyMissingError as e:
_delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(503, "DEPENDENCY_MISSING", e.message)
except Exception:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
_delete_temp_file_if_exists(parsed.tmp_path)
logging.exception("upload_asset failed for owner_id=%s", owner_id)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
payload = schemas_out.AssetCreated(
id=result.info.id,
name=result.info.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes) if result.asset.size_bytes else None,
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
preview_id=result.info.preview_id,
created_at=result.info.created_at,
last_access_time=result.info.last_access_time,
created_new=result.created_new,
)
status = 201 if result.created_new else 200
return web.json_response(payload.model_dump(mode="json"), status=status)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset(request: web.Request) -> web.Response:
async def update_asset_route(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
return _build_validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
return _build_error_response(
400, "INVALID_JSON", "Request body must be valid JSON."
)
try:
result = manager.update_asset(
result = update_asset_metadata(
asset_info_id=asset_info_id,
name=body.name,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
payload = schemas_out.AssetUpdated(
id=result.info.id,
name=result.info.name,
asset_hash=result.asset.hash if result.asset else None,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
updated_at=result.info.updated_at,
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
return _build_error_response(
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
)
except Exception:
logging.exception(
"update_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
async def delete_asset_route(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
delete_content = request.query.get("delete_content")
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
delete_content_param = request.query.get("delete_content")
delete_content = (
True
if delete_content_param is None
else delete_content_param.lower() not in {"0", "false", "no"}
)
try:
deleted = manager.delete_asset_reference(
deleted = delete_asset_reference(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
delete_content_if_orphan=delete_content,
@@ -398,10 +480,12 @@ async def delete_asset(request: web.Request) -> web.Response:
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
if not deleted:
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
return _build_error_response(
404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found."
)
return web.Response(status=204)
@@ -416,11 +500,17 @@ async def get_tags(request: web.Request) -> web.Response:
query = schemas_in.TagsListQuery.model_validate(query_map)
except ValidationError as e:
return web.json_response(
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}},
{
"error": {
"code": "INVALID_QUERY",
"message": "Invalid query parameters",
"details": e.errors(),
}
},
status=400,
)
result = manager.list_tags(
rows, total = list_tags(
prefix=query.prefix,
limit=query.limit,
offset=query.offset,
@@ -428,72 +518,108 @@ async def get_tags(request: web.Request) -> web.Response:
include_zero=query.include_zero,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(result.model_dump(mode="json"))
tags = [
schemas_out.TagUsage(name=name, count=count, type=tag_type)
for (name, tag_type, count) in rows
]
payload = schemas_out.TagsList(
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
)
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsAdd.model_validate(payload)
json_payload = await request.json()
data = schemas_in.TagsAdd.model_validate(json_payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
return _build_error_response(
400,
"INVALID_BODY",
"Invalid JSON body for tags add.",
{"errors": ve.errors()},
)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
return _build_error_response(
400, "INVALID_JSON", "Request body must be valid JSON."
)
try:
result = manager.add_tags_to_asset(
result = apply_tags(
asset_info_id=asset_info_id,
tags=data.tags,
origin="manual",
owner_id=USER_MANAGER.get_request_user_id(request),
)
payload = schemas_out.TagsAdd(
added=result.added,
already_present=result.already_present,
total_tags=result.total_tags,
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
return _build_error_response(
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
)
except Exception:
logging.exception(
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsRemove.model_validate(payload)
json_payload = await request.json()
data = schemas_in.TagsRemove.model_validate(json_payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
return _build_error_response(
400,
"INVALID_BODY",
"Invalid JSON body for tags remove.",
{"errors": ve.errors()},
)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
return _build_error_response(
400, "INVALID_JSON", "Request body must be valid JSON."
)
try:
result = manager.remove_tags_from_asset(
result = remove_tags(
asset_info_id=asset_info_id,
tags=data.tags,
owner_id=USER_MANAGER.get_request_user_id(request),
)
payload = schemas_out.TagsRemove(
removed=result.removed,
not_present=result.not_present,
total_tags=result.total_tags,
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
return _build_error_response(
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
)
except Exception:
logging.exception(
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/seed")
async def seed_assets_endpoint(request: web.Request) -> web.Response:
async def seed_assets(request: web.Request) -> web.Response:
"""Trigger asset seeding for specified roots (models, input, output)."""
try:
payload = await request.json()
@@ -503,12 +629,12 @@ async def seed_assets_endpoint(request: web.Request) -> web.Response:
valid_roots = [r for r in roots if r in ("models", "input", "output")]
if not valid_roots:
return _error_response(400, "INVALID_BODY", "No valid roots specified")
return _build_error_response(400, "INVALID_BODY", "No valid roots specified")
try:
seed_assets(tuple(valid_roots))
scanner_seed_assets(tuple(valid_roots))
except Exception:
logging.exception("seed_assets failed for roots=%s", valid_roots)
return _error_response(500, "INTERNAL", "Seed operation failed")
logging.exception("scanner_seed_assets failed for roots=%s", valid_roots)
return _build_error_response(500, "INTERNAL", "Seed operation failed")
return web.json_response({"seeded": valid_roots}, status=200)

View File

@@ -1,4 +1,5 @@
import json
from dataclasses import dataclass
from typing import Any, Literal
from pydantic import (
@@ -10,6 +11,61 @@ from pydantic import (
model_validator,
)
class UploadError(Exception):
"""Error during upload parsing with HTTP status and code (used in HTTP layer only)."""
def __init__(self, status: int, code: str, message: str):
super().__init__(message)
self.status = status
self.code = code
self.message = message
class AssetValidationError(Exception):
"""Validation error in asset processing (invalid tags, metadata, etc.)."""
def __init__(self, code: str, message: str):
super().__init__(message)
self.code = code
class AssetNotFoundError(Exception):
"""Asset or asset content not found."""
def __init__(self, message: str):
super().__init__(message)
class HashMismatchError(Exception):
"""Uploaded file hash does not match provided hash."""
pass
class DependencyMissingError(Exception):
"""A required dependency is not installed."""
def __init__(self, message: str):
super().__init__(message)
self.message = message
@dataclass
class ParsedUpload:
"""Result of parsing a multipart upload request."""
file_present: bool
file_written: int
file_client_name: str | None
tmp_path: str | None
tags_raw: list[str]
provided_name: str | None
user_metadata_raw: str | None
provided_hash: str | None
provided_hash_exists: bool | None
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
@@ -21,7 +77,9 @@ class ListAssetsQuery(BaseModel):
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
"created_at"
)
order: Literal["asc", "desc"] = "desc"
@field_validator("include_tags", "exclude_tags", mode="before")
@@ -61,7 +119,7 @@ class UpdateAssetBody(BaseModel):
user_metadata: dict[str, Any] | None = None
@model_validator(mode="after")
def _at_least_one(self):
def _validate_at_least_one_field(self):
if self.name is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, user_metadata.")
return self
@@ -90,7 +148,7 @@ class CreateFromHashBody(BaseModel):
@field_validator("tags", mode="before")
@classmethod
def _tags_norm(cls, v):
def _normalize_tags_field(cls, v):
if v is None:
return []
if isinstance(v, list):
@@ -163,6 +221,7 @@ class UploadAssetSpec(BaseModel):
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
and the original extension is preserved when available.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
@@ -260,5 +319,7 @@ class UploadAssetSpec(BaseModel):
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError("models uploads require a category tag as the second tag")
raise ValueError(
"models uploads require a category tag as the second tag"
)
return self

View File

@@ -19,7 +19,7 @@ class AssetSummary(BaseModel):
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "updated_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info):
def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None
@@ -40,7 +40,7 @@ class AssetUpdated(BaseModel):
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _ser_updated(self, v: datetime | None, _info):
def _serialize_updated_at(self, v: datetime | None, _info):
return v.isoformat() if v else None
@@ -59,7 +59,7 @@ class AssetDetail(BaseModel):
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info):
def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None

165
app/assets/api/upload.py Normal file
View File

@@ -0,0 +1,165 @@
import os
import uuid
from aiohttp import web
import folder_paths
from app.assets.api.schemas_in import ParsedUpload, UploadError
def normalize_and_validate_hash(s: str) -> str:
"""
Validate and normalize a hash string.
Returns canonical 'blake3:<hex>' or raises UploadError.
"""
s = s.strip().lower()
if not s:
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if ":" not in s:
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if (
algo != "blake3"
or not digest
or any(c for c in digest if c not in "0123456789abcdef")
):
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
return f"{algo}:{digest}"
async def parse_multipart_upload(
request: web.Request,
check_hash_exists: callable,
) -> ParsedUpload:
"""
Parse a multipart/form-data upload request.
Args:
request: The aiohttp request
check_hash_exists: Callable(hash_str) -> bool to check if a hash exists
Returns:
ParsedUpload with parsed fields and temp file path
Raises:
UploadError: On validation or I/O errors
"""
if not (request.content_type or "").lower().startswith("multipart/"):
raise UploadError(
415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads."
)
reader = await request.multipart()
file_present = False
file_client_name: str | None = None
tags_raw: list[str] = []
provided_name: str | None = None
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
file_written = 0
tmp_path: str | None = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
raise UploadError(
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
)
if s:
provided_hash = normalize_and_validate_hash(s)
try:
provided_hash_exists = check_hash_exists(provided_hash)
except Exception:
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
raise UploadError(
500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file."
)
continue
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
_delete_temp_file_if_exists(tmp_path)
raise UploadError(
500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file."
)
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError(
400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'."
)
if (
file_present
and file_written == 0
and not (provided_hash and provided_hash_exists)
):
_delete_temp_file_if_exists(tmp_path)
raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
return ParsedUpload(
file_present=file_present,
file_written=file_written,
file_client_name=file_client_name,
tmp_path=tmp_path,
tags_raw=tags_raw,
provided_name=provided_name,
user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists,
)
def _delete_temp_file_if_exists(tmp_path: str | None) -> None:
"""Safely remove a temp file if it exists."""
if tmp_path:
try:
if os.path.exists(tmp_path):
os.remove(tmp_path)
except Exception:
pass

View File

@@ -1,204 +0,0 @@
import os
import uuid
import sqlalchemy
from typing import Iterable
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import utcnow
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
MAX_BIND_PARAMS = 800
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
def seed_from_paths_batch(
session: Session,
*,
specs: list[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = sqlite.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
ins_state = (
sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
session.execute(ins_state, chunk)
# Query to find which of our paths won (were actually inserted)
winners_by_path: set[str] = set()
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetCacheState.file_path)
.where(AssetCacheState.file_path.in_(chunk))
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
)
winners_by_path.update(result.scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
ins_info = (
sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
)
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
session.execute(ins_info, chunk)
# Query to find which info rows were actually inserted (by matching our generated IDs)
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
)
inserted_info_ids.update(result.scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
def bulk_insert_tags_and_meta(
session: Session,
*,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if tag_rows:
ins_links = (
sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
session.execute(ins_links, chunk)
if meta_rows:
ins_meta = (
sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
session.execute(ins_meta, chunk)

View File

@@ -2,8 +2,8 @@ from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import (
JSON,
BigInteger,
@@ -20,19 +20,21 @@ from sqlalchemy import (
)
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
from app.assets.helpers import utcnow
from app.database.models import to_dict, Base
from app.assets.helpers import get_utc_now
from app.database.models import Base, to_dict
class Asset(Base):
__tablename__ = "assets"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[str | None] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
DateTime(timezone=False), nullable=False, default=get_utc_now
)
infos: Mapped[list[AssetInfo]] = relationship(
@@ -75,7 +77,9 @@ class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
asset_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False
)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
@@ -85,7 +89,9 @@ class AssetCacheState(Base):
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"
),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
@@ -99,15 +105,29 @@ class AssetCacheState(Base):
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
asset_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False
)
preview_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="SET NULL")
)
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True)
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
last_access_time: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
asset: Mapped[Asset] = relationship(
"Asset",
@@ -143,7 +163,9 @@ class AssetInfo(Base):
)
__table_args__ = (
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
UniqueConstraint(
"asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"
),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"),
@@ -196,7 +218,7 @@ class AssetInfoTag(Base):
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
DateTime(timezone=False), nullable=False, default=get_utc_now
)
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
@@ -225,9 +247,7 @@ class Tag(Base):
overlaps="asset_info_links,tag_links,tags,asset_info",
)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
__table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

View File

@@ -1,976 +0,0 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import (
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from typing import Sequence
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
from decimal import Decimal
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def asset_exists_by_hash(
session: Session,
*,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
def asset_info_exists_for_asset_id(
session: Session,
*,
asset_id: str,
) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (session.execute(q)).first() is not None
def get_asset_by_hash(
session: Session,
*,
asset_hash: str,
) -> Asset | None:
return (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
def get_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def list_asset_infos_page(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int((session.execute(count_stmt)).scalar_one() or 0)
infos = (session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
.order_by(AssetInfoTag.added_at)
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
session: Session,
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def touch_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
ts: datetime | None = None,
only_if_newer: bool = True,
) -> None:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
session.execute(stmt.values(last_access_time=ts))
def create_asset_info_for_existing_asset(
session: Session,
*,
asset_hash: str,
name: str,
user_metadata: dict | None = None,
tags: Sequence[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
with session.begin_nested():
session.add(info)
session.flush()
except IntegrityError:
existing = (
session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
def set_asset_info_tags(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
def replace_asset_info_metadata_projection(
session: Session,
*,
asset_info_id: str,
user_metadata: dict | None = None,
) -> None:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
session.flush()
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def ingest_fs_asset(
session: Session,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: dict | None = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
if preview_id:
if not session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
vals = {
"hash": asset_hash,
"size_bytes": int(size_bytes),
"mime_type": mime_type,
"created_at": now,
}
res = session.execute(
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
if int(res.rowcount or 0) > 0:
out["asset_created"] = True
asset = (
session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
if not asset:
raise RuntimeError("Asset row not found after upsert.")
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
ins = (
sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
res = session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
def update_asset_info_full(
session: Session,
*,
asset_info_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: dict | None = None,
tag_origin: str = "manual",
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
session.flush()
return info
def delete_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
owner_id: str,
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
return [
tag_name for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def add_tags_to_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)
def set_asset_info_preview(
session: Session,
*,
asset_info_id: str,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
session.flush()

View File

@@ -0,0 +1,99 @@
from app.assets.database.queries.asset import (
asset_exists_by_hash,
bulk_insert_assets,
get_asset_by_hash,
upsert_asset,
)
from app.assets.database.queries.asset_info import (
asset_info_exists_for_asset_id,
bulk_insert_asset_infos_ignore_conflicts,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_info_by_id,
get_asset_info_ids_by_ids,
get_or_create_asset_info,
insert_asset_info,
list_asset_infos_page,
set_asset_info_metadata,
set_asset_info_preview,
update_asset_info_access_time,
update_asset_info_name,
update_asset_info_timestamps,
update_asset_info_updated_at,
)
from app.assets.database.queries.cache_state import (
CacheStateRow,
bulk_insert_cache_states_ignore_conflicts,
bulk_set_needs_verify,
delete_assets_by_ids,
delete_cache_states_by_ids,
delete_cache_states_outside_prefixes,
delete_orphaned_seed_asset,
get_cache_states_by_paths_and_asset_ids,
get_cache_states_for_prefixes,
get_orphaned_seed_asset_ids,
list_cache_states_by_asset_id,
upsert_cache_state,
)
from app.assets.database.queries.tags import (
AddTagsDict,
RemoveTagsDict,
SetTagsDict,
add_missing_tag_for_asset_id,
add_tags_to_asset_info,
bulk_insert_tags_and_meta,
ensure_tags_exist,
get_asset_tags,
list_tags_with_usage,
remove_missing_tag_for_asset_id,
remove_tags_from_asset_info,
set_asset_info_tags,
)
__all__ = [
"AddTagsDict",
"CacheStateRow",
"RemoveTagsDict",
"SetTagsDict",
"add_missing_tag_for_asset_id",
"add_tags_to_asset_info",
"asset_exists_by_hash",
"asset_info_exists_for_asset_id",
"bulk_insert_asset_infos_ignore_conflicts",
"bulk_insert_assets",
"bulk_insert_cache_states_ignore_conflicts",
"bulk_insert_tags_and_meta",
"bulk_set_needs_verify",
"delete_asset_info_by_id",
"delete_assets_by_ids",
"delete_cache_states_by_ids",
"delete_cache_states_outside_prefixes",
"delete_orphaned_seed_asset",
"ensure_tags_exist",
"fetch_asset_info_and_asset",
"fetch_asset_info_asset_and_tags",
"get_asset_by_hash",
"get_asset_info_by_id",
"get_asset_info_ids_by_ids",
"get_asset_tags",
"get_cache_states_by_paths_and_asset_ids",
"get_cache_states_for_prefixes",
"get_or_create_asset_info",
"get_orphaned_seed_asset_ids",
"insert_asset_info",
"list_asset_infos_page",
"list_cache_states_by_asset_id",
"list_tags_with_usage",
"remove_missing_tag_for_asset_id",
"remove_tags_from_asset_info",
"set_asset_info_metadata",
"set_asset_info_preview",
"set_asset_info_tags",
"update_asset_info_access_time",
"update_asset_info_name",
"update_asset_info_timestamps",
"update_asset_info_updated_at",
"upsert_asset",
"upsert_cache_state",
]

View File

@@ -0,0 +1,90 @@
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session
from app.assets.database.models import Asset
from app.assets.database.queries.common import calculate_rows_per_statement, iter_chunks
def asset_exists_by_hash(
session: Session,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True))
.select_from(Asset)
.where(Asset.hash == asset_hash)
.limit(1)
)
).first()
return row is not None
def get_asset_by_hash(
session: Session,
asset_hash: str,
) -> Asset | None:
return (
(session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)))
.scalars()
.first()
)
def upsert_asset(
session: Session,
asset_hash: str,
size_bytes: int,
mime_type: str | None = None,
) -> tuple[Asset, bool, bool]:
"""Upsert an Asset by hash. Returns (asset, created, updated)."""
vals = {"hash": asset_hash, "size_bytes": int(size_bytes)}
if mime_type:
vals["mime_type"] = mime_type
ins = (
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
res = session.execute(ins)
created = int(res.rowcount or 0) > 0
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
.scalars()
.first()
)
if not asset:
raise RuntimeError("Asset row not found after upsert.")
updated = False
if not created:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
updated = True
return asset, created, updated
def bulk_insert_assets(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert Asset rows. Each dict should have: id, hash, size_bytes, mime_type, created_at."""
if not rows:
return
ins = sqlite.insert(Asset)
for chunk in iter_chunks(rows, calculate_rows_per_statement(5)):
session.execute(ins, chunk)

View File

@@ -0,0 +1,527 @@
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from typing import Sequence
import sqlalchemy as sa
from sqlalchemy import delete, exists, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import (
Asset,
AssetInfo,
AssetInfoMeta,
AssetInfoTag,
Tag,
)
from app.assets.database.queries.common import (
MAX_BIND_PARAMS,
build_visible_owner_clause,
calculate_rows_per_statement,
iter_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
def _check_is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
"""Convert a scalar value to a typed projection row."""
if value is None:
return {
"key": key,
"ordinal": ordinal,
"val_str": None,
"val_num": None,
"val_bool": None,
"val_json": None,
}
if isinstance(value, bool):
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
if isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return {"key": key, "ordinal": ordinal, "val_num": num}
if isinstance(value, str):
return {"key": key, "ordinal": ordinal, "val_str": value}
return {"key": key, "ordinal": ordinal, "val_json": value}
def convert_metadata_to_rows(key: str, value) -> list[dict]:
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
if value is None:
return [_scalar_to_row(key, 0, None)]
if _check_is_scalar(value):
return [_scalar_to_row(key, 0, value)]
if isinstance(value, list):
if all(_check_is_scalar(x) for x in value):
return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
return [{"key": key, "ordinal": 0, "val_json": value}]
def _apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def asset_info_exists_for_asset_id(
session: Session,
asset_id: str,
) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (session.execute(q)).first() is not None
def get_asset_info_by_id(
session: Session,
asset_info_id: str,
) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def insert_asset_info(
session: Session,
asset_id: str,
owner_id: str,
name: str,
preview_id: str | None = None,
) -> AssetInfo | None:
"""Insert a new AssetInfo. Returns None if unique constraint violated."""
now = get_utc_now()
try:
with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset_id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
except IntegrityError:
return None
def get_or_create_asset_info(
session: Session,
asset_id: str,
owner_id: str,
name: str,
preview_id: str | None = None,
) -> tuple[AssetInfo, bool]:
"""Get existing or create new AssetInfo. Returns (info, created)."""
info = insert_asset_info(
session,
asset_id=asset_id,
owner_id=owner_id,
name=name,
preview_id=preview_id,
)
if info:
return info, True
existing = (
session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset_id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
.unique()
.scalar_one_or_none()
)
if not existing:
raise RuntimeError("Failed to find AssetInfo after insert conflict.")
return existing, False
def update_asset_info_timestamps(
session: Session,
asset_info: AssetInfo,
preview_id: str | None = None,
) -> None:
"""Update timestamps and optionally preview_id on existing AssetInfo."""
now = get_utc_now()
if preview_id and asset_info.preview_id != preview_id:
asset_info.preview_id = preview_id
asset_info.updated_at = now
if asset_info.last_access_time < now:
asset_info.last_access_time = now
session.flush()
def list_asset_infos_page(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(build_visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = _apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(build_visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
total = int((session.execute(count_stmt)).scalar_one() or 0)
infos = (session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
.order_by(AssetInfoTag.added_at)
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
build_visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
build_visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
def update_asset_info_access_time(
session: Session,
asset_info_id: str,
ts: datetime | None = None,
only_if_newer: bool = True,
) -> None:
ts = ts or get_utc_now()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(
AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts
)
)
session.execute(stmt.values(last_access_time=ts))
def update_asset_info_name(
session: Session,
asset_info_id: str,
name: str,
) -> None:
"""Update the name of an AssetInfo."""
now = get_utc_now()
session.execute(
sa.update(AssetInfo)
.where(AssetInfo.id == asset_info_id)
.values(name=name, updated_at=now)
)
def update_asset_info_updated_at(
session: Session,
asset_info_id: str,
ts: datetime | None = None,
) -> None:
"""Update the updated_at timestamp of an AssetInfo."""
ts = ts or get_utc_now()
session.execute(
sa.update(AssetInfo).where(AssetInfo.id == asset_info_id).values(updated_at=ts)
)
def set_asset_info_metadata(
session: Session,
asset_info_id: str,
user_metadata: dict | None = None,
) -> None:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = get_utc_now()
session.flush()
session.execute(
delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)
)
session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in convert_metadata_to_rows(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def delete_asset_info_by_id(
session: Session,
asset_info_id: str,
owner_id: str,
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
build_visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def set_asset_info_preview(
session: Session,
asset_info_id: str,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = get_utc_now()
session.flush()
def bulk_insert_asset_infos_ignore_conflicts(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert AssetInfo rows with ON CONFLICT DO NOTHING.
Each dict should have: id, owner_id, name, asset_id, preview_id,
user_metadata, created_at, updated_at, last_access_time
"""
if not rows:
return
ins = sqlite.insert(AssetInfo).on_conflict_do_nothing(
index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]
)
for chunk in iter_chunks(rows, calculate_rows_per_statement(9)):
session.execute(ins, chunk)
def get_asset_info_ids_by_ids(
session: Session,
info_ids: list[str],
) -> set[str]:
"""Query to find which AssetInfo IDs exist in the database."""
if not info_ids:
return set()
found: set[str] = set()
for chunk in iter_chunks(info_ids, MAX_BIND_PARAMS):
result = session.execute(select(AssetInfo.id).where(AssetInfo.id.in_(chunk)))
found.update(result.scalars().all())
return found

View File

@@ -0,0 +1,280 @@
import os
from typing import NamedTuple, Sequence
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.database.queries.common import (
MAX_BIND_PARAMS,
calculate_rows_per_statement,
iter_chunks,
)
from app.assets.helpers import escape_sql_like_string
class CacheStateRow(NamedTuple):
"""Row from cache state query with joined asset data."""
state_id: int
file_path: str
mtime_ns: int | None
needs_verify: bool
asset_id: str
asset_hash: str | None
size_bytes: int
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
(
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
)
.scalars()
.all()
)
def upsert_cache_state(
session: Session,
asset_id: str,
file_path: str,
mtime_ns: int,
) -> tuple[bool, bool]:
"""Upsert a cache state by file_path. Returns (created, updated)."""
vals = {
"asset_id": asset_id,
"file_path": file_path,
"mtime_ns": int(mtime_ns),
}
ins = (
sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
res = session.execute(ins)
created = int(res.rowcount or 0) > 0
if created:
return True, False
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == file_path)
.where(
sa.or_(
AssetCacheState.asset_id != asset_id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset_id, mtime_ns=int(mtime_ns))
)
res2 = session.execute(upd)
updated = int(res2.rowcount or 0) > 0
return False, updated
def delete_cache_states_outside_prefixes(
session: Session, valid_prefixes: list[str]
) -> int:
"""Delete cache states with file_path not matching any of the valid prefixes.
Args:
session: Database session
valid_prefixes: List of absolute directory prefixes that are valid
Returns:
Number of cache states deleted
"""
if not valid_prefixes:
return 0
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_sql_like_string(base)
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sa.or_(*[make_prefix_condition(p) for p in valid_prefixes])
result = session.execute(sa.delete(AssetCacheState).where(~matches_valid_prefix))
return result.rowcount
def get_orphaned_seed_asset_ids(session: Session) -> list[str]:
"""Get IDs of seed assets (hash is None) with no remaining cache states.
Returns:
List of asset IDs that are orphaned
"""
orphan_subq = (
sa.select(Asset.id)
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
)
return [row[0] for row in session.execute(orphan_subq).all()]
def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int:
"""Delete assets and their AssetInfos by ID.
Args:
session: Database session
asset_ids: List of asset IDs to delete
Returns:
Number of assets deleted
"""
if not asset_ids:
return 0
session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id.in_(asset_ids)))
result = session.execute(sa.delete(Asset).where(Asset.id.in_(asset_ids)))
return result.rowcount
def get_cache_states_for_prefixes(
session: Session,
prefixes: list[str],
) -> list[CacheStateRow]:
"""Get all cache states with paths matching any of the given prefixes.
Args:
session: Database session
prefixes: List of absolute directory prefixes to match
Returns:
List of cache state rows with joined asset data, ordered by asset_id, state_id
"""
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_sql_like_string(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
rows = session.execute(
sa.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
).all()
return [
CacheStateRow(
state_id=row[0],
file_path=row[1],
mtime_ns=row[2],
needs_verify=row[3],
asset_id=row[4],
asset_hash=row[5],
size_bytes=int(row[6] or 0),
)
for row in rows
]
def bulk_set_needs_verify(session: Session, state_ids: list[int], value: bool) -> int:
"""Set needs_verify flag for multiple cache states.
Returns: Number of rows updated
"""
if not state_ids:
return 0
result = session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(state_ids))
.values(needs_verify=value)
)
return result.rowcount
def delete_cache_states_by_ids(session: Session, state_ids: list[int]) -> int:
"""Delete cache states by their IDs.
Returns: Number of rows deleted
"""
if not state_ids:
return 0
result = session.execute(
sa.delete(AssetCacheState).where(AssetCacheState.id.in_(state_ids))
)
return result.rowcount
def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool:
"""Delete a seed asset (hash is None) and its AssetInfos.
Returns: True if asset was deleted, False if not found
"""
session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == asset_id))
asset = session.get(Asset, asset_id)
if asset:
session.delete(asset)
return True
return False
def bulk_insert_cache_states_ignore_conflicts(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert cache state rows with ON CONFLICT DO NOTHING on file_path.
Each dict should have: asset_id, file_path, mtime_ns
"""
if not rows:
return
ins = sqlite.insert(AssetCacheState).on_conflict_do_nothing(
index_elements=[AssetCacheState.file_path]
)
for chunk in iter_chunks(rows, calculate_rows_per_statement(3)):
session.execute(ins, chunk)
def get_cache_states_by_paths_and_asset_ids(
session: Session,
path_to_asset: dict[str, str],
) -> set[str]:
"""Query cache states to find paths where our asset_id won the insert.
Args:
path_to_asset: Mapping of file_path -> asset_id we tried to insert
Returns:
Set of file_paths where our asset_id is present
"""
if not path_to_asset:
return set()
paths = list(path_to_asset.keys())
winners: set[str] = set()
for chunk in iter_chunks(paths, MAX_BIND_PARAMS):
result = session.execute(
select(AssetCacheState.file_path).where(
AssetCacheState.file_path.in_(chunk),
AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]),
)
)
winners.update(result.scalars().all())
return winners

View File

@@ -0,0 +1,37 @@
"""Shared utilities for database query modules."""
from typing import Iterable
import sqlalchemy as sa
from app.assets.database.models import AssetInfo
MAX_BIND_PARAMS = 800
def calculate_rows_per_statement(cols: int) -> int:
"""Calculate how many rows can fit in one statement given column count."""
return max(1, MAX_BIND_PARAMS // max(1, cols))
def iter_chunks(seq, n: int):
"""Yield successive n-sized chunks from seq."""
for i in range(0, len(seq), n):
yield seq[i : i + n]
def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]:
"""Yield chunks of rows sized to fit within bind param limits."""
if not rows:
return []
rows_per_stmt = max(1, MAX_BIND_PARAMS // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i : i + rows_per_stmt]
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])

View File

@@ -0,0 +1,349 @@
from typing import Iterable, Sequence, TypedDict
import sqlalchemy as sa
from sqlalchemy import delete, func, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.assets.database.models import AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.database.queries.common import (
build_visible_owner_clause,
iter_row_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
class AddTagsDict(TypedDict):
added: list[str]
already_present: list[str]
total_tags: list[str]
class RemoveTagsDict(TypedDict):
removed: list[str]
not_present: list[str]
total_tags: list[str]
class SetTagsDict(TypedDict):
added: list[str]
removed: list[str]
total: list[str]
def ensure_tags_exist(
session: Session, names: Iterable[str], tag_type: str = "user"
) -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_asset_tags(session: Session, asset_info_id: str) -> list[str]:
return [
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
)
)
).all()
]
def set_asset_info_tags(
session: Session,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> SetTagsDict:
desired = normalize_tags(tags)
current = set(
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
)
)
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
)
for t in to_add
]
)
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
def add_tags_to_asset_info(
session: Session,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: AssetInfo | None = None,
) -> AddTagsDict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
asset_info_id: str,
tags: Sequence[str],
) -> RemoveTagsDict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def add_missing_tag_for_asset_id(
session: Session,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sa.select(
AssetInfo.id.label("asset_info_id"),
sa.literal("missing").label("tag_name"),
sa.literal(origin).label("origin"),
sa.literal(get_utc_now()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sa.not_(
sa.exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == "missing")
)
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(
index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]
)
)
def remove_missing_tag_for_asset_id(
session: Session,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(
sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)
),
AssetInfoTag.tag_name == "missing",
)
)
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(build_visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def bulk_insert_tags_and_meta(
session: Session,
tag_rows: list[dict],
meta_rows: list[dict],
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
Args:
session: Database session
tag_rows: List of dicts with keys: asset_info_id, tag_name, origin, added_at
meta_rows: List of dicts with keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if tag_rows:
ins_tags = sqlite.insert(AssetInfoTag).on_conflict_do_nothing(
index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]
)
for chunk in iter_row_chunks(tag_rows, cols_per_row=4):
session.execute(ins_tags, chunk)
if meta_rows:
ins_meta = sqlite.insert(AssetInfoMeta).on_conflict_do_nothing(
index_elements=[
AssetInfoMeta.asset_info_id,
AssetInfoMeta.key,
AssetInfoMeta.ordinal,
]
)
for chunk in iter_row_chunks(meta_rows, cols_per_row=7):
session.execute(ins_meta, chunk)

View File

@@ -1,62 +0,0 @@
from typing import Iterable
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import normalize_tags, utcnow
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
return session.execute(ins)
def add_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sqlalchemy.select(
AssetInfo.id.label("asset_info_id"),
sqlalchemy.literal("missing").label("tag_name"),
sqlalchemy.literal(origin).label("origin"),
sqlalchemy.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sqlalchemy.not_(
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sqlalchemy.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)

View File

@@ -1,75 +0,0 @@
from blake3 import blake3
from typing import IO
import os
import asyncio
DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB
# NOTE: this allows hashing different representations of a file-like object
def blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""
Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``, ``seek``, and ``tell``. The function will seek to the start before
reading and will attempt to restore the original position afterward.
"""
# duck typing to check if input is a file-like object
if hasattr(fp, "read"):
return _hash_file_obj(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
async def blake3_hash_async(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# If it is a path, open inside the worker thread to keep I/O off the loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
return await asyncio.to_thread(_worker)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
"""
Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
# in case file object is already open and not at the beginning, track so can be restored after hashing
orig_pos = file_obj.tell()
try:
# seek to the beginning before reading
if orig_pos != 0:
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
# restore original position in file object, if needed
if orig_pos != 0:
file_obj.seek(orig_pos)

View File

@@ -1,52 +1,36 @@
import contextlib
import os
from decimal import Decimal
from aiohttp import web
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal, Any
import folder_paths
from typing import Literal, Sequence
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
def get_query_dict(request: web.Request) -> dict[str, Any]:
def select_best_live_path(states: Sequence) -> str:
"""
Gets a dictionary of query parameters from the request.
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
query_dict = {
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
for key in request.query.keys()
}
return query_dict
alive = [
s
for s in states
if getattr(s, "file_path", None) and os.path.isfile(s.file_path)
]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def list_tree(base_dir: str) -> list[str]:
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out
def prefixes_for_root(root: RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
ALLOWED_ROOTS: tuple[Literal["models", "input", "output"], ...] = (
"models",
"input",
"output",
)
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char itself in a LIKE prefix.
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
"""
@@ -54,173 +38,11 @@ def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
return s, escape
def fast_asset_file_check(
*,
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
if mtime_db is None:
return False
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True
def utcnow() -> datetime:
def get_utc_now() -> datetime:
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
return datetime.now(timezone.utc).replace(tzinfo=None)
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def ensure_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
def normalize_tags(tags: list[str] | None) -> list[str]:
"""
@@ -229,84 +51,3 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
- Removing duplicates.
"""
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out
def is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def project_kv(key: str, value):
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows

View File

@@ -1,516 +0,0 @@
import os
import mimetypes
import contextlib
from typing import Sequence
from app.database.db import create_session
from app.assets.api import schemas_out, schemas_in
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
create_asset_info_for_existing_asset,
touch_asset_info_by_id,
update_asset_info_full,
delete_asset_info_by_id,
list_cache_states_by_asset_id,
list_asset_infos_page,
list_tags_with_usage,
get_asset_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
pick_best_live_path,
ingest_fs_asset,
set_asset_info_preview,
)
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
from app.assets.database.models import Asset
def _safe_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback
def asset_exists(*, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def list_assets(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
with create_session() as session:
infos, tag_map, total = list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
summaries: list[schemas_out.AssetSummary] = []
for info in infos:
asset = info.asset
tags = tag_map.get(info.id, [])
summaries.append(
schemas_out.AssetSummary(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
)
return schemas_out.AssetsList(
assets=summaries,
total=total,
has_more=(offset + len(summaries)) < total,
)
def get_asset(
*,
asset_info_id: str,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
preview_id = info.preview_id
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
def resolve_asset_content_for_download(
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
with create_session() as session:
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = pick_best_live_path(states)
if not abs_path:
raise FileNotFoundError
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
def upload_asset_from_temp_path(
spec: schemas_in.UploadAssetSpec,
*,
temp_path: str,
client_filename: str | None = None,
owner_id: str = "",
expected_asset_hash: str | None = None,
) -> schemas_out.AssetCreated:
"""
Create new asset or update existing asset from a temporary file path.
"""
try:
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
import app.assets.hashing as hashing
digest = hashing.blake3_hash(temp_path)
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
info = create_asset_info_for_existing_asset(
session,
asset_hash=asset_hash,
name=display_name,
user_metadata=spec.user_metadata or {},
tags=spec.tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or spec.name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
ensure_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
with create_session() as session:
result = ingest_fs_asset(
session,
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = get_asset_tags(session, asset_info_id=info.id)
created_result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
)
session.commit()
return created_result
def update_asset(
*,
asset_info_id: str,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
tag_origin="manual",
asset_info_row=info_row,
)
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
result = schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
session.commit()
return result
def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
result = schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
session.commit()
return result
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
session.commit()
return False
if not delete_content_if_orphan or not asset_id:
session.commit()
return True
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetCreated | None:
canonical = hash_str.strip().lower()
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
session.commit()
return result
def add_tags_to_asset(
*,
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
session.commit()
return schemas_out.TagsAdd(**data)
def remove_tags_from_asset(
*,
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
session.commit()
return schemas_out.TagsRemove(**data)
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)

View File

@@ -1,263 +1,318 @@
import contextlib
import time
import logging
import os
import sqlalchemy
import time
from typing import Literal, TypedDict
import folder_paths
from app.database.db import create_session, dependencies_available
from app.assets.helpers import (
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
list_tree,prefixes_for_root, escape_like_prefix,
RootType
from app.assets.database.queries import (
add_missing_tag_for_asset_id,
bulk_set_needs_verify,
delete_cache_states_by_ids,
delete_orphaned_seed_asset,
ensure_tags_exist,
get_cache_states_for_prefixes,
remove_missing_tag_for_asset_id,
)
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
from app.assets.database.bulk_ops import seed_from_paths_batch
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.services.bulk_ingest import (
SeedAssetSpec,
batch_insert_seed_assets,
prune_orphaned_assets,
)
from app.assets.services.file_utils import (
get_mtime_ns,
list_files_recursively,
verify_file_unchanged,
)
from app.assets.services.path_utils import (
compute_relative_filename,
get_comfy_models_folders,
get_name_and_tags_from_asset_path,
)
from app.database.db import create_session, dependencies_available
class _StateInfo(TypedDict):
sid: int
fp: str
exists: bool
fast_ok: bool
needs_verify: bool
class _AssetAccumulator(TypedDict):
hash: str | None
size_db: int
states: list[_StateInfo]
RootType = Literal["models", "input", "output"]
def get_prefixes_for_root(root: RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out
def sync_cache_states_with_filesystem(
session,
root: RootType,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> set[str] | None:
"""Reconcile cache states with filesystem for a root.
- Toggle needs_verify per state using fast mtime/size check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
Args:
session: Database session
root: Root type to scan
collect_existing_paths: If True, return set of surviving file paths
update_missing_tags: If True, update 'missing' tags based on file status
Returns:
Set of surviving absolute paths if collect_existing_paths=True, else None
"""
prefixes = get_prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
rows = get_cache_states_for_prefixes(session, prefixes)
by_asset: dict[str, _AssetAccumulator] = {}
for row in rows:
acc = by_asset.get(row.asset_id)
if acc is None:
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "states": []}
by_asset[row.asset_id] = acc
fast_ok = False
try:
exists = True
fast_ok = verify_file_unchanged(
mtime_db=row.mtime_ns,
size_db=acc["size_db"],
stat_result=os.stat(row.file_path, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append(
{
"sid": row.state_id,
"fp": row.file_path,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": row.needs_verify,
}
)
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing:
delete_orphaned_seed_asset(session, aid)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok:
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
remove_missing_tag_for_asset_id(session, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
delete_cache_states_by_ids(session, stale_state_ids)
bulk_set_needs_verify(session, to_set_verify, value=True)
bulk_set_needs_verify(session, to_clear_verify, value=False)
return survivors if collect_existing_paths else None
def _sync_root_safely(root: RootType) -> set[str]:
"""Sync a single root's cache states with the filesystem.
Returns survivors (existing paths) or empty set on failure.
"""
try:
with create_session() as sess:
survivors = sync_cache_states_with_filesystem(
sess,
root,
collect_existing_paths=True,
update_missing_tags=True,
)
sess.commit()
return survivors or set()
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", root, e)
return set()
def _prune_orphans_safely(prefixes: list[str]) -> int:
"""Prune orphaned assets outside the given prefixes.
Returns count pruned or 0 on failure.
"""
try:
with create_session() as sess:
count = prune_orphaned_assets(sess, prefixes)
sess.commit()
return count
except Exception as e:
logging.exception("orphan pruning failed: %s", e)
return 0
def _collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
"""Collect all file paths for the given roots."""
paths: list[str] = []
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_files_recursively(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_files_recursively(folder_paths.get_output_directory()))
return paths
def _build_asset_specs(
paths: list[str],
existing_paths: set[str],
) -> tuple[list[SeedAssetSpec], set[str], int]:
"""Build asset specs from paths, returning (specs, tag_pool, skipped_count)."""
specs: list[SeedAssetSpec] = []
tag_pool: set[str] = set()
skipped = 0
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped += 1
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=False)
except OSError:
continue
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": get_mtime_ns(stat_p),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(abs_p),
}
)
tag_pool.update(tags)
return specs, tag_pool, skipped
def _insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
"""Insert asset specs into database, returning count of created infos."""
if not specs:
return 0
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
sess.commit()
return result.inserted_infos
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
"""
Scan the given roots and seed the assets into the database.
"""
"""Scan the given roots and seed the assets into the database."""
if not dependencies_available():
if enable_logging:
logging.warning("Database dependencies not available, skipping assets scan")
return
t_start = time.perf_counter()
created = 0
skipped_existing = 0
orphans_pruned = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
for r in roots:
try:
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
if survivors:
existing_paths.update(survivors)
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", r, e)
try:
orphans_pruned = _prune_orphaned_assets(roots)
except Exception as e:
logging.exception("orphan pruning failed: %s", e)
existing_paths: set[str] = set()
for r in roots:
existing_paths.update(_sync_root_safely(r))
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
all_prefixes = [os.path.abspath(p) for r in roots for p in get_prefixes_for_root(r)]
orphans_pruned = _prune_orphans_safely(all_prefixes)
specs: list[dict] = []
tag_pool: set[str] = set()
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped_existing += 1
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=False)
except OSError:
continue
# skip empty files
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(abs_p),
}
)
for t in tags:
tag_pool.add(t)
# if no file specs, nothing to do
if not specs:
return
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
paths = _collect_paths_for_roots(roots)
specs, tag_pool, skipped_existing = _build_asset_specs(paths, existing_paths)
created = _insert_asset_specs(specs, tag_pool)
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
sess.commit()
finally:
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
skipped_existing,
orphans_pruned,
len(paths),
)
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
if not all_prefixes:
return 0
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_like_prefix(base)
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
orphan_subq = (
sqlalchemy.select(Asset.id)
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
).scalar_subquery()
with create_session() as sess:
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
sess.commit()
return result.rowcount
def _fast_db_consistency_pass(
root: RootType,
*,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> set[str] | None:
"""Fast DB+FS pass for a root:
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
"""
prefixes = prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
with create_session() as sess:
rows = (
sess.execute(
sqlalchemy.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sqlalchemy.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
by_asset: dict[str, dict] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
fast_ok = False
try:
exists = True
fast_ok = fast_asset_file_check(
mtime_db=mtime_db,
size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append({
"sid": sid,
"fp": fp,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": bool(needs_verify),
})
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = sess.get(Asset, aid)
if asset:
sess.delete(asset)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
remove_missing_tag_for_asset_id(sess, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
)
if to_clear_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear_verify))
.values(needs_verify=False)
)
sess.commit()
return survivors if collect_existing_paths else None
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
skipped_existing,
orphans_pruned,
len(paths),
)

View File

@@ -0,0 +1,91 @@
from app.assets.services.asset_management import (
asset_exists,
delete_asset_reference,
get_asset_by_hash,
get_asset_detail,
list_assets_page,
resolve_asset_for_download,
set_asset_preview,
update_asset_metadata,
)
from app.assets.services.bulk_ingest import (
BulkInsertResult,
batch_insert_seed_assets,
prune_orphaned_assets,
)
from app.assets.services.file_utils import (
get_mtime_ns,
get_size_and_mtime_ns,
list_files_recursively,
verify_file_unchanged,
)
from app.assets.services.ingest import (
DependencyMissingError,
HashMismatchError,
create_from_hash,
ingest_file_from_path,
register_existing_asset,
upload_from_temp_path,
)
from app.assets.services.schemas import (
AddTagsResult,
AssetData,
AssetDetailResult,
AssetInfoData,
AssetSummaryData,
DownloadResolutionResult,
IngestResult,
ListAssetsResult,
RegisterAssetResult,
RemoveTagsResult,
SetTagsResult,
TagUsage,
UploadResult,
UserMetadata,
)
from app.assets.services.tagging import (
apply_tags,
list_tags,
remove_tags,
)
__all__ = [
"AddTagsResult",
"AssetData",
"AssetDetailResult",
"AssetInfoData",
"AssetSummaryData",
"BulkInsertResult",
"DependencyMissingError",
"DownloadResolutionResult",
"HashMismatchError",
"IngestResult",
"ListAssetsResult",
"RegisterAssetResult",
"RemoveTagsResult",
"SetTagsResult",
"TagUsage",
"UploadResult",
"UserMetadata",
"apply_tags",
"asset_exists",
"batch_insert_seed_assets",
"create_from_hash",
"delete_asset_reference",
"get_asset_by_hash",
"get_asset_detail",
"get_mtime_ns",
"get_size_and_mtime_ns",
"ingest_file_from_path",
"list_assets_page",
"list_files_recursively",
"list_tags",
"prune_orphaned_assets",
"register_existing_asset",
"remove_tags",
"resolve_asset_for_download",
"set_asset_preview",
"update_asset_metadata",
"upload_from_temp_path",
"verify_file_unchanged",
]

View File

@@ -0,0 +1,290 @@
import contextlib
import mimetypes
import os
from typing import Sequence
from app.assets.database.models import Asset
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_by_hash as queries_get_asset_by_hash,
get_asset_info_by_id,
list_asset_infos_page,
list_cache_states_by_asset_id,
set_asset_info_metadata,
set_asset_info_preview,
set_asset_info_tags,
update_asset_info_access_time,
update_asset_info_name,
update_asset_info_updated_at,
)
from app.assets.helpers import select_best_live_path
from app.assets.services.path_utils import compute_filename_for_asset
from app.assets.services.schemas import (
AssetData,
AssetDetailResult,
AssetSummaryData,
DownloadResolutionResult,
ListAssetsResult,
UserMetadata,
extract_asset_data,
extract_info_data,
)
from app.database.db import create_session
def get_asset_detail(
asset_info_id: str,
owner_id: str = "",
) -> AssetDetailResult | None:
with create_session() as session:
result = fetch_asset_info_asset_and_tags(
session,
asset_info_id=asset_info_id,
owner_id=owner_id,
)
if not result:
return None
info, asset, tags = result
return AssetDetailResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tags,
)
def update_asset_metadata(
asset_info_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: UserMetadata = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetDetailResult:
with create_session() as session:
info = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info.owner_id and info.owner_id != owner_id:
raise PermissionError("not owner")
touched = False
if name is not None and name != info.name:
update_asset_info_name(session, asset_info_id=asset_info_id, name=name)
touched = True
computed_filename = compute_filename_for_asset(session, info.asset_id)
new_meta: dict | None = None
if user_metadata is not None:
new_meta = dict(user_metadata)
elif computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
if new_meta is not None:
if computed_filename:
new_meta["filename"] = computed_filename
set_asset_info_metadata(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
update_asset_info_updated_at(session, asset_info_id=asset_info_id)
result = fetch_asset_info_asset_and_tags(
session,
asset_info_id=asset_info_id,
owner_id=owner_id,
)
if not result:
raise RuntimeError("State changed during update")
info, asset, tag_list = result
detail = AssetDetailResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tag_list,
)
session.commit()
return detail
def delete_asset_reference(
asset_info_id: str,
owner_id: str,
delete_content_if_orphan: bool = True,
) -> bool:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = delete_asset_info_by_id(
session, asset_info_id=asset_info_id, owner_id=owner_id
)
if not deleted:
session.commit()
return False
if not delete_content_if_orphan or not asset_id:
session.commit()
return True
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
# Orphaned asset - delete it and its files
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [
s.file_path for s in (states or []) if getattr(s, "file_path", None)
]
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
# Delete files after commit
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def set_asset_preview(
asset_info_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> AssetDetailResult:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
result = fetch_asset_info_asset_and_tags(
session, asset_info_id=asset_info_id, owner_id=owner_id
)
if not result:
raise RuntimeError("State changed during preview update")
info, asset, tags = result
detail = AssetDetailResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tags,
)
session.commit()
return detail
def asset_exists(asset_hash: str) -> bool:
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def get_asset_by_hash(asset_hash: str) -> AssetData | None:
with create_session() as session:
asset = queries_get_asset_by_hash(session, asset_hash=asset_hash)
return extract_asset_data(asset)
def list_assets_page(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> ListAssetsResult:
with create_session() as session:
infos, tag_map, total = list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
items: list[AssetSummaryData] = []
for info in infos:
items.append(
AssetSummaryData(
info=extract_info_data(info),
asset=extract_asset_data(info.asset),
tags=tag_map.get(info.id, []),
)
)
return ListAssetsResult(items=items, total=total)
def resolve_asset_for_download(
asset_info_id: str,
owner_id: str = "",
) -> DownloadResolutionResult:
with create_session() as session:
pair = fetch_asset_info_and_asset(
session, asset_info_id=asset_info_id, owner_id=owner_id
)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = select_best_live_path(states)
if not abs_path:
raise FileNotFoundError
update_asset_info_access_time(session, asset_info_id=asset_info_id)
session.commit()
ctype = (
asset.mime_type
or mimetypes.guess_type(info.name or abs_path)[0]
or "application/octet-stream"
)
download_name = info.name or os.path.basename(abs_path)
return DownloadResolutionResult(
abs_path=abs_path,
content_type=ctype,
download_name=download_name,
)

View File

@@ -0,0 +1,203 @@
import os
import uuid
from dataclasses import dataclass
from typing import TypedDict
from sqlalchemy.orm import Session
class SeedAssetSpec(TypedDict):
"""Spec for seeding an asset from filesystem."""
abs_path: str
size_bytes: int
mtime_ns: int
info_name: str
tags: list[str]
fname: str
from app.assets.database.queries import (
bulk_insert_asset_infos_ignore_conflicts,
bulk_insert_assets,
bulk_insert_cache_states_ignore_conflicts,
bulk_insert_tags_and_meta,
delete_assets_by_ids,
delete_cache_states_outside_prefixes,
get_asset_info_ids_by_ids,
get_cache_states_by_paths_and_asset_ids,
get_orphaned_seed_asset_ids,
)
from app.assets.helpers import get_utc_now
@dataclass
class BulkInsertResult:
"""Result of bulk asset insertion."""
inserted_infos: int
won_states: int
lost_states: int
def batch_insert_seed_assets(
session: Session,
specs: list[SeedAssetSpec],
owner_id: str = "",
) -> BulkInsertResult:
"""Seed assets from filesystem specs in batch.
Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
This function orchestrates:
1. Insert seed Assets (hash=NULL)
2. Claim cache states with ON CONFLICT DO NOTHING
3. Query to find winners (paths where our asset_id was inserted)
4. Delete Assets for losers (path already claimed by another asset)
5. Insert AssetInfo for winners
6. Insert tags and metadata for successfully inserted AssetInfos
Returns:
BulkInsertResult with inserted_infos, won_states, lost_states
"""
if not specs:
return BulkInsertResult(inserted_infos=0, won_states=0, lost_states=0)
now = get_utc_now()
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {}
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
bulk_insert_assets(session, asset_rows)
bulk_insert_cache_states_ignore_conflicts(session, state_rows)
winners_by_path = get_cache_states_by_paths_and_asset_ids(session, path_to_asset)
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets:
delete_assets_by_ids(session, lost_assets)
if not winners_by_path:
return BulkInsertResult(
inserted_infos=0,
won_states=0,
lost_states=len(losers_by_path),
)
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
db_info_rows = [
{
"id": row["id"],
"owner_id": row["owner_id"],
"name": row["name"],
"asset_id": row["asset_id"],
"preview_id": row["preview_id"],
"user_metadata": row["user_metadata"],
"created_at": row["created_at"],
"updated_at": row["updated_at"],
"last_access_time": row["last_access_time"],
}
for row in winner_info_rows
]
bulk_insert_asset_infos_ignore_conflicts(session, db_info_rows)
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids = get_asset_info_ids_by_ids(session, all_info_ids)
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append(
{
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
}
)
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows)
return BulkInsertResult(
inserted_infos=len(inserted_info_ids),
won_states=len(winners_by_path),
lost_states=len(losers_by_path),
)
def prune_orphaned_assets(session: Session, valid_prefixes: list[str]) -> int:
"""Prune cache states outside valid prefixes, then delete orphaned seed assets.
Args:
session: Database session
valid_prefixes: List of absolute directory prefixes that are valid
Returns:
Number of orphaned assets deleted
"""
delete_cache_states_outside_prefixes(session, valid_prefixes)
orphan_ids = get_orphaned_seed_asset_ids(session)
return delete_assets_by_ids(session, orphan_ids)

View File

@@ -0,0 +1,49 @@
import os
def get_mtime_ns(stat_result: os.stat_result) -> int:
"""Extract mtime in nanoseconds from a stat result."""
return getattr(
stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)
)
def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]:
"""Get file size in bytes and mtime in nanoseconds."""
st = os.stat(path, follow_symlinks=follow_symlinks)
return st.st_size, get_mtime_ns(st)
def verify_file_unchanged(
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
"""Check if a file is unchanged based on mtime and size.
Returns True if the file's mtime and size match the database values.
Returns False if mtime_db is None or values don't match.
"""
if mtime_db is None:
return False
actual_mtime_ns = get_mtime_ns(stat_result)
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True
def list_files_recursively(base_dir: str) -> list[str]:
"""Recursively list all files in a directory."""
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(
base_abs, topdown=True, followlinks=False
):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out

View File

@@ -0,0 +1,54 @@
import asyncio
import os
from typing import IO
from blake3 import blake3
DEFAULT_CHUNK = 8 * 1024 * 1024
def compute_blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
if hasattr(fp, "read"):
return _hash_file_obj(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
async def compute_compute_blake3_hash_async(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
if hasattr(fp, "read"):
return await asyncio.to_thread(compute_blake3_hash, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
return await asyncio.to_thread(_worker)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
orig_pos = file_obj.tell()
try:
if orig_pos != 0:
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
if orig_pos != 0:
file_obj.seek(orig_pos)

View File

@@ -0,0 +1,388 @@
import contextlib
import logging
import mimetypes
import os
from typing import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
import app.assets.services.hashing as hashing
from app.assets.database.models import Asset, AssetInfo, Tag
from app.assets.database.queries import (
add_tags_to_asset_info,
fetch_asset_info_and_asset,
get_asset_by_hash,
get_asset_tags,
get_or_create_asset_info,
remove_missing_tag_for_asset_id,
set_asset_info_metadata,
set_asset_info_tags,
update_asset_info_timestamps,
upsert_asset,
upsert_cache_state,
)
from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import (
compute_filename_for_asset,
resolve_destination_from_tags,
validate_path_within_base,
)
from app.assets.services.schemas import (
IngestResult,
RegisterAssetResult,
UploadResult,
UserMetadata,
extract_asset_data,
extract_info_data,
)
from app.database.db import create_session
def ingest_file_from_path(
abs_path: str,
asset_hash: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: UserMetadata = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> IngestResult:
locator = os.path.abspath(abs_path)
asset_created = False
asset_updated = False
state_created = False
state_updated = False
asset_info_id: str | None = None
with create_session() as session:
if preview_id:
if not session.get(Asset, preview_id):
preview_id = None
asset, asset_created, asset_updated = upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=size_bytes,
mime_type=mime_type,
)
state_created, state_updated = upsert_cache_state(
session,
asset_id=asset.id,
file_path=locator,
mtime_ns=mtime_ns,
)
if info_name:
info, info_created = get_or_create_asset_info(
session,
asset_id=asset.id,
owner_id=owner_id,
name=info_name,
preview_id=preview_id,
)
if info_created:
asset_info_id = info.id
else:
update_asset_info_timestamps(
session, asset_info=info, preview_id=preview_id
)
asset_info_id = info.id
norm = normalize_tags(list(tags))
if norm and asset_info_id:
if require_existing_tags:
_validate_tags_exist(session, norm)
add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=norm,
origin=tag_origin,
create_if_missing=not require_existing_tags,
)
if asset_info_id:
_update_metadata_with_filename(
session,
asset_info_id=asset_info_id,
asset_id=asset.id,
info=info,
user_metadata=user_metadata,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
session.commit()
return IngestResult(
asset_created=asset_created,
asset_updated=asset_updated,
state_created=state_created,
state_updated=state_updated,
asset_info_id=asset_info_id,
)
def register_existing_asset(
asset_hash: str,
name: str,
user_metadata: UserMetadata = None,
tags: list[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> RegisterAssetResult:
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"No asset with hash {asset_hash}")
info, info_created = get_or_create_asset_info(
session,
asset_id=asset.id,
owner_id=owner_id,
name=name,
preview_id=None,
)
if not info_created:
tag_names = get_asset_tags(session, asset_info_id=info.id)
result = RegisterAssetResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tag_names,
created=False,
)
session.commit()
return result
new_meta = dict(user_metadata or {})
computed_filename = compute_filename_for_asset(session, asset.id)
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
set_asset_info_metadata(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.refresh(info)
result = RegisterAssetResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tag_names,
created=True,
)
session.commit()
return result
def _validate_tags_exist(session: Session, tags: list[str]) -> None:
existing_tag_names = set(
name
for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
)
missing = [t for t in tags if t not in existing_tag_names]
if missing:
raise ValueError(f"Unknown tags: {missing}")
def _update_metadata_with_filename(
session: Session,
asset_info_id: str,
asset_id: str,
info: AssetInfo,
user_metadata: UserMetadata,
) -> None:
computed_filename = compute_filename_for_asset(session, asset_id)
current_meta = info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
set_asset_info_metadata(
session,
asset_info_id=asset_info_id,
user_metadata=new_meta,
)
def _sanitize_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
return n if n else fallback
class HashMismatchError(Exception):
pass
class DependencyMissingError(Exception):
def __init__(self, message: str):
self.message = message
super().__init__(message)
def upload_from_temp_path(
temp_path: str,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
client_filename: str | None = None,
owner_id: str = "",
expected_hash: str | None = None,
) -> UploadResult:
try:
digest = hashing.compute_blake3_hash(temp_path)
except ImportError as e:
raise DependencyMissingError(str(e))
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_hash and asset_hash != expected_hash.strip().lower():
raise HashMismatchError("Uploaded file hash does not match provided hash.")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _sanitize_filename(name or client_filename, fallback=digest)
result = register_existing_asset(
asset_hash=asset_hash,
name=display_name,
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
return UploadResult(
info=result.info,
asset=result.asset,
tags=result.tags,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
validate_path_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
ingest_result = ingest_file_from_path(
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_sanitize_filename(name or client_filename, fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=user_metadata or {},
tags=tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = ingest_result.asset_info_id
if not info_id:
raise RuntimeError("failed to create asset metadata")
with create_session() as session:
pair = fetch_asset_info_and_asset(
session, asset_info_id=info_id, owner_id=owner_id
)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = get_asset_tags(session, asset_info_id=info.id)
return UploadResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tag_names,
created_new=ingest_result.asset_created,
)
def create_from_hash(
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> UploadResult | None:
canonical = hash_str.strip().lower()
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
result = register_existing_asset(
asset_hash=canonical,
name=_sanitize_filename(
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
return UploadResult(
info=result.info,
asset=result.asset,
tags=result.tags,
created_new=False,
)

View File

@@ -0,0 +1,184 @@
import os
from pathlib import Path
from typing import Literal
import folder_paths
from app.assets.helpers import normalize_tags
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = (
values[0],
values[1],
) # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory()
if root == "input"
else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def validate_path_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_asset_category_and_relative_path(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_asset_category_and_relative_path(
file_path: str,
) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _check_is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _compute_relative(child: str, parent: str) -> str:
return os.path.relpath(
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _check_is_within(fp_abs, input_base):
return "input", _compute_relative(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _check_is_within(fp_abs, output_base):
return "output", _compute_relative(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _check_is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(
f"Path is not within input, output, or configured model bases: {file_path}"
)
def compute_filename_for_asset(session, asset_id: str) -> str | None:
"""Compute the relative filename for an asset from its best live cache state path."""
from app.assets.database.queries import list_cache_states_by_asset_id
from app.assets.helpers import select_best_live_path
primary_path = select_best_live_path(
list_cache_states_by_asset_id(session, asset_id=asset_id)
)
return compute_relative_filename(primary_path) if primary_path else None
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_asset_category_and_relative_path`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_asset_category_and_relative_path(file_path)
p = Path(some_path)
parent_parts = [
part for part in p.parent.parts if part not in (".", "..", p.anchor)
]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))

View File

@@ -0,0 +1,126 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NamedTuple
from app.assets.database.models import Asset, AssetInfo
UserMetadata = dict[str, Any] | None
@dataclass(frozen=True)
class AssetData:
hash: str
size_bytes: int | None
mime_type: str | None
@dataclass(frozen=True)
class AssetInfoData:
id: str
name: str
user_metadata: UserMetadata
preview_id: str | None
created_at: datetime
updated_at: datetime
last_access_time: datetime | None
@dataclass(frozen=True)
class AssetDetailResult:
info: AssetInfoData
asset: AssetData | None
tags: list[str]
@dataclass(frozen=True)
class RegisterAssetResult:
info: AssetInfoData
asset: AssetData
tags: list[str]
created: bool
@dataclass(frozen=True)
class IngestResult:
asset_created: bool
asset_updated: bool
state_created: bool
state_updated: bool
asset_info_id: str | None
@dataclass(frozen=True)
class AddTagsResult:
added: list[str]
already_present: list[str]
total_tags: list[str]
@dataclass(frozen=True)
class RemoveTagsResult:
removed: list[str]
not_present: list[str]
total_tags: list[str]
@dataclass(frozen=True)
class SetTagsResult:
added: list[str]
removed: list[str]
total: list[str]
class TagUsage(NamedTuple):
name: str
tag_type: str
count: int
@dataclass(frozen=True)
class AssetSummaryData:
info: AssetInfoData
asset: AssetData | None
tags: list[str]
@dataclass(frozen=True)
class ListAssetsResult:
items: list[AssetSummaryData]
total: int
@dataclass(frozen=True)
class DownloadResolutionResult:
abs_path: str
content_type: str
download_name: str
@dataclass(frozen=True)
class UploadResult:
info: AssetInfoData
asset: AssetData
tags: list[str]
created_new: bool
def extract_info_data(info: AssetInfo) -> AssetInfoData:
return AssetInfoData(
id=info.id,
name=info.name,
user_metadata=info.user_metadata,
preview_id=info.preview_id,
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
def extract_asset_data(asset: Asset | None) -> AssetData | None:
if asset is None:
return None
return AssetData(
hash=asset.hash,
size_bytes=asset.size_bytes,
mime_type=asset.mime_type,
)

View File

@@ -0,0 +1,89 @@
from app.assets.database.queries import (
add_tags_to_asset_info,
get_asset_info_by_id,
list_tags_with_usage,
remove_tags_from_asset_info,
)
from app.assets.services.schemas import AddTagsResult, RemoveTagsResult, TagUsage
from app.database.db import create_session
def apply_tags(
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> AddTagsResult:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
session.commit()
return AddTagsResult(
added=data["added"],
already_present=data["already_present"],
total_tags=data["total_tags"],
)
def remove_tags(
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> RemoveTagsResult:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
session.commit()
return RemoveTagsResult(
removed=data["removed"],
not_present=data["not_present"],
total_tags=data["total_tags"],
)
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> tuple[list[TagUsage], int]:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total

View File

@@ -259,13 +259,3 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str):
for aid in ids:
with contextlib.suppress(Exception):
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
"""Force a fast sync/seed pass by calling the seed endpoint."""
session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30)
time.sleep(0.2)
def get_asset_filename(asset_hash: str, extension: str) -> str:
return asset_hash.removeprefix("blake3:") + extension

View File

@@ -0,0 +1,14 @@
"""Helper functions for assets integration tests."""
import time
import requests
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
"""Force a fast sync/seed pass by calling the seed endpoint."""
session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30)
time.sleep(0.2)
def get_asset_filename(asset_hash: str, extension: str) -> str:
return asset_hash.removeprefix("blake3:") + extension

View File

@@ -0,0 +1,20 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from app.assets.database.models import Base
@pytest.fixture
def session():
"""In-memory SQLite session for fast unit tests."""
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as sess:
yield sess
@pytest.fixture(autouse=True)
def autoclean_unit_test_assets():
"""Override parent autouse fixture - query tests don't need server cleanup."""
yield

View File

@@ -0,0 +1,142 @@
import uuid
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset
from app.assets.database.queries import (
asset_exists_by_hash,
get_asset_by_hash,
upsert_asset,
bulk_insert_assets,
)
class TestAssetExistsByHash:
@pytest.mark.parametrize(
"setup_hash,query_hash,expected",
[
(None, "nonexistent", False), # No asset exists
("blake3:abc123", "blake3:abc123", True), # Asset exists with matching hash
(None, "", False), # Null hash in DB doesn't match empty string
],
ids=["nonexistent", "existing", "null_hash_no_match"],
)
def test_exists_by_hash(self, session: Session, setup_hash, query_hash, expected):
if setup_hash is not None or setup_hash is None:
# Create asset with given hash (including None for null hash test)
if setup_hash is not None or query_hash == "":
asset = Asset(hash=setup_hash, size_bytes=100)
session.add(asset)
session.commit()
assert asset_exists_by_hash(session, asset_hash=query_hash) is expected
class TestGetAssetByHash:
@pytest.mark.parametrize(
"setup_hash,query_hash,should_find",
[
(None, "nonexistent", False),
("blake3:def456", "blake3:def456", True),
],
ids=["nonexistent", "existing"],
)
def test_get_by_hash(self, session: Session, setup_hash, query_hash, should_find):
if setup_hash is not None:
asset = Asset(hash=setup_hash, size_bytes=200, mime_type="image/png")
session.add(asset)
session.commit()
result = get_asset_by_hash(session, asset_hash=query_hash)
if should_find:
assert result is not None
assert result.size_bytes == 200
assert result.mime_type == "image/png"
else:
assert result is None
class TestUpsertAsset:
@pytest.mark.parametrize(
"first_size,first_mime,second_size,second_mime,expect_created,expect_updated,final_size,final_mime",
[
# New asset creation
(None, None, 1024, "application/octet-stream", True, False, 1024, "application/octet-stream"),
# Existing asset, same values - no update
(500, "text/plain", 500, "text/plain", False, False, 500, "text/plain"),
# Existing asset with size 0, update with new values
(0, None, 2048, "image/png", False, True, 2048, "image/png"),
# Existing asset, second call with size 0 - no update
(1000, None, 0, None, False, False, 1000, None),
],
ids=["new_asset", "existing_no_change", "update_from_zero", "zero_size_no_update"],
)
def test_upsert_scenarios(
self,
session: Session,
first_size,
first_mime,
second_size,
second_mime,
expect_created,
expect_updated,
final_size,
final_mime,
):
asset_hash = f"blake3:test_{first_size}_{second_size}"
# First upsert (if first_size is not None, we're testing the second call)
if first_size is not None:
upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=first_size,
mime_type=first_mime,
)
session.commit()
# The upsert call we're testing
asset, created, updated = upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=second_size,
mime_type=second_mime,
)
session.commit()
assert created is expect_created
assert updated is expect_updated
assert asset.size_bytes == final_size
assert asset.mime_type == final_mime
class TestBulkInsertAssets:
def test_inserts_multiple_assets(self, session: Session):
rows = [
{"id": str(uuid.uuid4()), "hash": "blake3:bulk1", "size_bytes": 100, "mime_type": "text/plain", "created_at": None},
{"id": str(uuid.uuid4()), "hash": "blake3:bulk2", "size_bytes": 200, "mime_type": "image/png", "created_at": None},
{"id": str(uuid.uuid4()), "hash": "blake3:bulk3", "size_bytes": 300, "mime_type": None, "created_at": None},
]
bulk_insert_assets(session, rows)
session.commit()
assets = session.query(Asset).all()
assert len(assets) == 3
hashes = {a.hash for a in assets}
assert hashes == {"blake3:bulk1", "blake3:bulk2", "blake3:bulk3"}
def test_empty_list_is_noop(self, session: Session):
bulk_insert_assets(session, [])
session.commit()
assert session.query(Asset).count() == 0
def test_handles_large_batch(self, session: Session):
"""Test chunking logic with more rows than MAX_BIND_PARAMS allows."""
rows = [
{"id": str(uuid.uuid4()), "hash": f"blake3:large{i}", "size_bytes": i, "mime_type": None, "created_at": None}
for i in range(200)
]
bulk_insert_assets(session, rows)
session.commit()
assert session.query(Asset).count() == 200

View File

@@ -0,0 +1,511 @@
import time
import uuid
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta
from app.assets.database.queries import (
asset_info_exists_for_asset_id,
get_asset_info_by_id,
insert_asset_info,
get_or_create_asset_info,
update_asset_info_timestamps,
list_asset_infos_page,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
update_asset_info_access_time,
set_asset_info_metadata,
delete_asset_info_by_id,
set_asset_info_preview,
bulk_insert_asset_infos_ignore_conflicts,
get_asset_info_ids_by_ids,
ensure_tags_exist,
add_tags_to_asset_info,
)
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream")
session.add(asset)
session.flush()
return asset
def _make_asset_info(
session: Session,
asset: Asset,
name: str = "test",
owner_id: str = "",
) -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
class TestAssetInfoExistsForAssetId:
def test_returns_false_when_no_info(self, session: Session):
asset = _make_asset(session, "hash1")
assert asset_info_exists_for_asset_id(session, asset_id=asset.id) is False
def test_returns_true_when_info_exists(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset)
assert asset_info_exists_for_asset_id(session, asset_id=asset.id) is True
class TestGetAssetInfoById:
def test_returns_none_for_nonexistent(self, session: Session):
assert get_asset_info_by_id(session, asset_info_id="nonexistent") is None
def test_returns_info(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, name="myfile.txt")
result = get_asset_info_by_id(session, asset_info_id=info.id)
assert result is not None
assert result.name == "myfile.txt"
class TestListAssetInfosPage:
def test_empty_db(self, session: Session):
infos, tag_map, total = list_asset_infos_page(session)
assert infos == []
assert tag_map == {}
assert total == 0
def test_returns_infos_with_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, name="test.bin")
ensure_tags_exist(session, ["alpha", "beta"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["alpha", "beta"])
session.commit()
infos, tag_map, total = list_asset_infos_page(session)
assert len(infos) == 1
assert infos[0].id == info.id
assert set(tag_map[info.id]) == {"alpha", "beta"}
assert total == 1
def test_name_contains_filter(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, name="model_v1.safetensors")
_make_asset_info(session, asset, name="config.json")
session.commit()
infos, _, total = list_asset_infos_page(session, name_contains="model")
assert total == 1
assert infos[0].name == "model_v1.safetensors"
def test_owner_visibility(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, name="public", owner_id="")
_make_asset_info(session, asset, name="private", owner_id="user1")
session.commit()
# Empty owner sees only public
infos, _, total = list_asset_infos_page(session, owner_id="")
assert total == 1
assert infos[0].name == "public"
# Owner sees both
infos, _, total = list_asset_infos_page(session, owner_id="user1")
assert total == 2
def test_include_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1")
info1 = _make_asset_info(session, asset, name="tagged")
_make_asset_info(session, asset, name="untagged")
ensure_tags_exist(session, ["wanted"])
add_tags_to_asset_info(session, asset_info_id=info1.id, tags=["wanted"])
session.commit()
infos, _, total = list_asset_infos_page(session, include_tags=["wanted"])
assert total == 1
assert infos[0].name == "tagged"
def test_exclude_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, name="keep")
info_exclude = _make_asset_info(session, asset, name="exclude")
ensure_tags_exist(session, ["bad"])
add_tags_to_asset_info(session, asset_info_id=info_exclude.id, tags=["bad"])
session.commit()
infos, _, total = list_asset_infos_page(session, exclude_tags=["bad"])
assert total == 1
assert infos[0].name == "keep"
def test_sorting(self, session: Session):
asset = _make_asset(session, "hash1", size=100)
asset2 = _make_asset(session, "hash2", size=500)
_make_asset_info(session, asset, name="small")
_make_asset_info(session, asset2, name="large")
session.commit()
infos, _, _ = list_asset_infos_page(session, sort="size", order="desc")
assert infos[0].name == "large"
infos, _, _ = list_asset_infos_page(session, sort="name", order="asc")
assert infos[0].name == "large"
class TestFetchAssetInfoAssetAndTags:
def test_returns_none_for_nonexistent(self, session: Session):
result = fetch_asset_info_asset_and_tags(session, "nonexistent")
assert result is None
def test_returns_tuple(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, name="test.bin")
ensure_tags_exist(session, ["tag1"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["tag1"])
session.commit()
result = fetch_asset_info_asset_and_tags(session, info.id)
assert result is not None
ret_info, ret_asset, ret_tags = result
assert ret_info.id == info.id
assert ret_asset.id == asset.id
assert ret_tags == ["tag1"]
class TestFetchAssetInfoAndAsset:
def test_returns_none_for_nonexistent(self, session: Session):
result = fetch_asset_info_and_asset(session, asset_info_id="nonexistent")
assert result is None
def test_returns_tuple(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
result = fetch_asset_info_and_asset(session, asset_info_id=info.id)
assert result is not None
ret_info, ret_asset = result
assert ret_info.id == info.id
assert ret_asset.id == asset.id
class TestUpdateAssetInfoAccessTime:
def test_updates_last_access_time(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
original_time = info.last_access_time
session.commit()
import time
time.sleep(0.01)
update_asset_info_access_time(session, asset_info_id=info.id)
session.commit()
session.refresh(info)
assert info.last_access_time > original_time
class TestDeleteAssetInfoById:
def test_deletes_existing(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
result = delete_asset_info_by_id(session, asset_info_id=info.id, owner_id="")
assert result is True
assert get_asset_info_by_id(session, asset_info_id=info.id) is None
def test_returns_false_for_nonexistent(self, session: Session):
result = delete_asset_info_by_id(session, asset_info_id="nonexistent", owner_id="")
assert result is False
def test_respects_owner_visibility(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
result = delete_asset_info_by_id(session, asset_info_id=info.id, owner_id="user2")
assert result is False
assert get_asset_info_by_id(session, asset_info_id=info.id) is not None
class TestSetAssetInfoPreview:
def test_sets_preview(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
info = _make_asset_info(session, asset)
session.commit()
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id=preview_asset.id)
session.commit()
session.refresh(info)
assert info.preview_id == preview_asset.id
def test_clears_preview(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
info = _make_asset_info(session, asset)
info.preview_id = preview_asset.id
session.commit()
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id=None)
session.commit()
session.refresh(info)
assert info.preview_id is None
def test_raises_for_nonexistent_info(self, session: Session):
with pytest.raises(ValueError, match="not found"):
set_asset_info_preview(session, asset_info_id="nonexistent", preview_asset_id=None)
def test_raises_for_nonexistent_preview(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
with pytest.raises(ValueError, match="Preview Asset"):
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id="nonexistent")
class TestInsertAssetInfo:
def test_creates_new_info(self, session: Session):
asset = _make_asset(session, "hash1")
info = insert_asset_info(
session, asset_id=asset.id, owner_id="user1", name="test.bin"
)
session.commit()
assert info is not None
assert info.name == "test.bin"
assert info.owner_id == "user1"
def test_returns_none_on_conflict(self, session: Session):
asset = _make_asset(session, "hash1")
insert_asset_info(session, asset_id=asset.id, owner_id="user1", name="dup.bin")
session.commit()
# Attempt duplicate with same (asset_id, owner_id, name)
result = insert_asset_info(
session, asset_id=asset.id, owner_id="user1", name="dup.bin"
)
assert result is None
class TestGetOrCreateAssetInfo:
def test_creates_new_info(self, session: Session):
asset = _make_asset(session, "hash1")
info, created = get_or_create_asset_info(
session, asset_id=asset.id, owner_id="user1", name="new.bin"
)
session.commit()
assert created is True
assert info.name == "new.bin"
def test_returns_existing_info(self, session: Session):
asset = _make_asset(session, "hash1")
info1, created1 = get_or_create_asset_info(
session, asset_id=asset.id, owner_id="user1", name="existing.bin"
)
session.commit()
info2, created2 = get_or_create_asset_info(
session, asset_id=asset.id, owner_id="user1", name="existing.bin"
)
session.commit()
assert created1 is True
assert created2 is False
assert info1.id == info2.id
class TestUpdateAssetInfoTimestamps:
def test_updates_timestamps(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
original_updated_at = info.updated_at
session.commit()
time.sleep(0.01)
update_asset_info_timestamps(session, info)
session.commit()
session.refresh(info)
assert info.updated_at > original_updated_at
def test_updates_preview_id(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
info = _make_asset_info(session, asset)
session.commit()
update_asset_info_timestamps(session, info, preview_id=preview_asset.id)
session.commit()
session.refresh(info)
assert info.preview_id == preview_asset.id
class TestSetAssetInfoMetadata:
def test_sets_metadata(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={"key": "value"}
)
session.commit()
session.refresh(info)
assert info.user_metadata == {"key": "value"}
# Check metadata table
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
assert len(meta) == 1
assert meta[0].key == "key"
assert meta[0].val_str == "value"
def test_replaces_existing_metadata(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={"old": "data"}
)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={"new": "data"}
)
session.commit()
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
assert len(meta) == 1
assert meta[0].key == "new"
def test_clears_metadata_with_empty_dict(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={"key": "value"}
)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={}
)
session.commit()
session.refresh(info)
assert info.user_metadata == {}
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
assert len(meta) == 0
def test_raises_for_nonexistent(self, session: Session):
with pytest.raises(ValueError, match="not found"):
set_asset_info_metadata(
session, asset_info_id="nonexistent", user_metadata={"key": "value"}
)
class TestBulkInsertAssetInfosIgnoreConflicts:
def test_inserts_multiple_infos(self, session: Session):
asset = _make_asset(session, "hash1")
now = get_utc_now()
rows = [
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "bulk1.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "bulk2.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
]
bulk_insert_asset_infos_ignore_conflicts(session, rows)
session.commit()
infos = session.query(AssetInfo).all()
assert len(infos) == 2
def test_ignores_conflicts(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, name="existing.bin", owner_id="")
session.commit()
now = get_utc_now()
rows = [
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "existing.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "new.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
]
bulk_insert_asset_infos_ignore_conflicts(session, rows)
session.commit()
infos = session.query(AssetInfo).all()
assert len(infos) == 2 # existing + new, not 3
def test_empty_list_is_noop(self, session: Session):
bulk_insert_asset_infos_ignore_conflicts(session, [])
assert session.query(AssetInfo).count() == 0
class TestGetAssetInfoIdsByIds:
def test_returns_existing_ids(self, session: Session):
asset = _make_asset(session, "hash1")
info1 = _make_asset_info(session, asset, name="a.bin")
info2 = _make_asset_info(session, asset, name="b.bin")
session.commit()
found = get_asset_info_ids_by_ids(session, [info1.id, info2.id, "nonexistent"])
assert found == {info1.id, info2.id}
def test_empty_list_returns_empty(self, session: Session):
found = get_asset_info_ids_by_ids(session, [])
assert found == set()

View File

@@ -0,0 +1,416 @@
"""Tests for cache_state query functions."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.database.queries import (
list_cache_states_by_asset_id,
upsert_cache_state,
delete_cache_states_outside_prefixes,
get_orphaned_seed_asset_ids,
delete_assets_by_ids,
get_cache_states_for_prefixes,
bulk_set_needs_verify,
delete_cache_states_by_ids,
delete_orphaned_seed_asset,
bulk_insert_cache_states_ignore_conflicts,
get_cache_states_by_paths_and_asset_ids,
)
from app.assets.helpers import select_best_live_path, get_utc_now
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size)
session.add(asset)
session.flush()
return asset
def _make_cache_state(
session: Session,
asset: Asset,
file_path: str,
mtime_ns: int | None = None,
needs_verify: bool = False,
) -> AssetCacheState:
state = AssetCacheState(
asset_id=asset.id,
file_path=file_path,
mtime_ns=mtime_ns,
needs_verify=needs_verify,
)
session.add(state)
session.flush()
return state
class TestListCacheStatesByAssetId:
def test_returns_empty_for_no_states(self, session: Session):
asset = _make_asset(session, "hash1")
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
assert list(states) == []
def test_returns_states_for_asset(self, session: Session):
asset = _make_asset(session, "hash1")
_make_cache_state(session, asset, "/path/a.bin")
_make_cache_state(session, asset, "/path/b.bin")
session.commit()
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
paths = [s.file_path for s in states]
assert set(paths) == {"/path/a.bin", "/path/b.bin"}
def test_does_not_return_other_assets_states(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_cache_state(session, asset1, "/path/asset1.bin")
_make_cache_state(session, asset2, "/path/asset2.bin")
session.commit()
states = list_cache_states_by_asset_id(session, asset_id=asset1.id)
paths = [s.file_path for s in states]
assert paths == ["/path/asset1.bin"]
class TestSelectBestLivePath:
def test_returns_empty_for_empty_list(self):
result = select_best_live_path([])
assert result == ""
def test_returns_empty_when_no_files_exist(self, session: Session):
asset = _make_asset(session, "hash1")
state = _make_cache_state(session, asset, "/nonexistent/path.bin")
session.commit()
result = select_best_live_path([state])
assert result == ""
def test_prefers_verified_path(self, session: Session, tmp_path):
"""needs_verify=False should be preferred."""
asset = _make_asset(session, "hash1")
verified_file = tmp_path / "verified.bin"
verified_file.write_bytes(b"data")
unverified_file = tmp_path / "unverified.bin"
unverified_file.write_bytes(b"data")
state_verified = _make_cache_state(
session, asset, str(verified_file), needs_verify=False
)
state_unverified = _make_cache_state(
session, asset, str(unverified_file), needs_verify=True
)
session.commit()
states = [state_unverified, state_verified]
result = select_best_live_path(states)
assert result == str(verified_file)
def test_falls_back_to_existing_unverified(self, session: Session, tmp_path):
"""If all states need verification, return first existing path."""
asset = _make_asset(session, "hash1")
existing_file = tmp_path / "exists.bin"
existing_file.write_bytes(b"data")
state = _make_cache_state(session, asset, str(existing_file), needs_verify=True)
session.commit()
result = select_best_live_path([state])
assert result == str(existing_file)
class TestSelectBestLivePathWithMocking:
def test_handles_missing_file_path_attr(self):
"""Gracefully handle states with None file_path."""
class MockState:
file_path = None
needs_verify = False
result = select_best_live_path([MockState()])
assert result == ""
class TestUpsertCacheState:
@pytest.mark.parametrize(
"initial_mtime,second_mtime,expect_created,expect_updated,final_mtime",
[
# New state creation
(None, 12345, True, False, 12345),
# Existing state, same mtime - no update
(100, 100, False, False, 100),
# Existing state, different mtime - update
(100, 200, False, True, 200),
],
ids=["new_state", "existing_no_change", "existing_update_mtime"],
)
def test_upsert_scenarios(
self, session: Session, initial_mtime, second_mtime, expect_created, expect_updated, final_mtime
):
asset = _make_asset(session, "hash1")
file_path = f"/path_{initial_mtime}_{second_mtime}.bin"
# Create initial state if needed
if initial_mtime is not None:
upsert_cache_state(session, asset_id=asset.id, file_path=file_path, mtime_ns=initial_mtime)
session.commit()
# The upsert call we're testing
created, updated = upsert_cache_state(
session, asset_id=asset.id, file_path=file_path, mtime_ns=second_mtime
)
session.commit()
assert created is expect_created
assert updated is expect_updated
state = session.query(AssetCacheState).filter_by(file_path=file_path).one()
assert state.mtime_ns == final_mtime
class TestDeleteCacheStatesOutsidePrefixes:
def test_deletes_states_outside_prefixes(self, session: Session, tmp_path):
asset = _make_asset(session, "hash1")
valid_dir = tmp_path / "valid"
valid_dir.mkdir()
invalid_dir = tmp_path / "invalid"
invalid_dir.mkdir()
valid_path = str(valid_dir / "file.bin")
invalid_path = str(invalid_dir / "file.bin")
_make_cache_state(session, asset, valid_path)
_make_cache_state(session, asset, invalid_path)
session.commit()
deleted = delete_cache_states_outside_prefixes(session, [str(valid_dir)])
session.commit()
assert deleted == 1
remaining = session.query(AssetCacheState).all()
assert len(remaining) == 1
assert remaining[0].file_path == valid_path
def test_empty_prefixes_deletes_nothing(self, session: Session):
asset = _make_asset(session, "hash1")
_make_cache_state(session, asset, "/some/path.bin")
session.commit()
deleted = delete_cache_states_outside_prefixes(session, [])
assert deleted == 0
class TestGetOrphanedSeedAssetIds:
def test_returns_orphaned_seed_assets(self, session: Session):
# Seed asset (hash=None) with no cache states
orphan = _make_asset(session, hash_val=None)
# Seed asset with cache state (not orphaned)
with_state = _make_asset(session, hash_val=None)
_make_cache_state(session, with_state, "/has/state.bin")
# Regular asset (hash not None) - should not be returned
_make_asset(session, hash_val="blake3:regular")
session.commit()
orphaned = get_orphaned_seed_asset_ids(session)
assert orphan.id in orphaned
assert with_state.id not in orphaned
class TestDeleteAssetsByIds:
def test_deletes_assets_and_infos(self, session: Session):
asset = _make_asset(session, "hash1")
now = get_utc_now()
info = AssetInfo(
owner_id="", name="test", asset_id=asset.id,
created_at=now, updated_at=now, last_access_time=now
)
session.add(info)
session.commit()
deleted = delete_assets_by_ids(session, [asset.id])
session.commit()
assert deleted == 1
assert session.query(Asset).count() == 0
assert session.query(AssetInfo).count() == 0
def test_empty_list_deletes_nothing(self, session: Session):
_make_asset(session, "hash1")
session.commit()
deleted = delete_assets_by_ids(session, [])
assert deleted == 0
assert session.query(Asset).count() == 1
class TestGetCacheStatesForPrefixes:
def test_returns_states_matching_prefix(self, session: Session, tmp_path):
asset = _make_asset(session, "hash1")
dir1 = tmp_path / "dir1"
dir1.mkdir()
dir2 = tmp_path / "dir2"
dir2.mkdir()
path1 = str(dir1 / "file.bin")
path2 = str(dir2 / "file.bin")
_make_cache_state(session, asset, path1, mtime_ns=100)
_make_cache_state(session, asset, path2, mtime_ns=200)
session.commit()
rows = get_cache_states_for_prefixes(session, [str(dir1)])
assert len(rows) == 1
assert rows[0].file_path == path1
def test_empty_prefixes_returns_empty(self, session: Session):
asset = _make_asset(session, "hash1")
_make_cache_state(session, asset, "/some/path.bin")
session.commit()
rows = get_cache_states_for_prefixes(session, [])
assert rows == []
class TestBulkSetNeedsVerify:
def test_sets_needs_verify_flag(self, session: Session):
asset = _make_asset(session, "hash1")
state1 = _make_cache_state(session, asset, "/path1.bin", needs_verify=False)
state2 = _make_cache_state(session, asset, "/path2.bin", needs_verify=False)
session.commit()
updated = bulk_set_needs_verify(session, [state1.id, state2.id], True)
session.commit()
assert updated == 2
session.refresh(state1)
session.refresh(state2)
assert state1.needs_verify is True
assert state2.needs_verify is True
def test_empty_list_updates_nothing(self, session: Session):
updated = bulk_set_needs_verify(session, [], True)
assert updated == 0
class TestDeleteCacheStatesByIds:
def test_deletes_states_by_id(self, session: Session):
asset = _make_asset(session, "hash1")
state1 = _make_cache_state(session, asset, "/path1.bin")
_make_cache_state(session, asset, "/path2.bin")
session.commit()
deleted = delete_cache_states_by_ids(session, [state1.id])
session.commit()
assert deleted == 1
assert session.query(AssetCacheState).count() == 1
def test_empty_list_deletes_nothing(self, session: Session):
deleted = delete_cache_states_by_ids(session, [])
assert deleted == 0
class TestDeleteOrphanedSeedAsset:
@pytest.mark.parametrize(
"create_asset,expected_deleted,expected_count",
[
(True, True, 0), # Existing asset gets deleted
(False, False, 0), # Nonexistent returns False
],
ids=["deletes_existing", "nonexistent_returns_false"],
)
def test_delete_orphaned_seed_asset(
self, session: Session, create_asset, expected_deleted, expected_count
):
asset_id = "nonexistent-id"
if create_asset:
asset = _make_asset(session, hash_val=None)
asset_id = asset.id
now = get_utc_now()
info = AssetInfo(
owner_id="", name="test", asset_id=asset.id,
created_at=now, updated_at=now, last_access_time=now
)
session.add(info)
session.commit()
deleted = delete_orphaned_seed_asset(session, asset_id)
if create_asset:
session.commit()
assert deleted is expected_deleted
assert session.query(Asset).count() == expected_count
class TestBulkInsertCacheStatesIgnoreConflicts:
def test_inserts_multiple_states(self, session: Session):
asset = _make_asset(session, "hash1")
rows = [
{"asset_id": asset.id, "file_path": "/bulk1.bin", "mtime_ns": 100},
{"asset_id": asset.id, "file_path": "/bulk2.bin", "mtime_ns": 200},
]
bulk_insert_cache_states_ignore_conflicts(session, rows)
session.commit()
assert session.query(AssetCacheState).count() == 2
def test_ignores_conflicts(self, session: Session):
asset = _make_asset(session, "hash1")
_make_cache_state(session, asset, "/existing.bin", mtime_ns=100)
session.commit()
rows = [
{"asset_id": asset.id, "file_path": "/existing.bin", "mtime_ns": 999},
{"asset_id": asset.id, "file_path": "/new.bin", "mtime_ns": 200},
]
bulk_insert_cache_states_ignore_conflicts(session, rows)
session.commit()
assert session.query(AssetCacheState).count() == 2
existing = session.query(AssetCacheState).filter_by(file_path="/existing.bin").one()
assert existing.mtime_ns == 100 # Original value preserved
def test_empty_list_is_noop(self, session: Session):
bulk_insert_cache_states_ignore_conflicts(session, [])
assert session.query(AssetCacheState).count() == 0
class TestGetCacheStatesByPathsAndAssetIds:
def test_returns_matching_paths(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_cache_state(session, asset1, "/path1.bin")
_make_cache_state(session, asset2, "/path2.bin")
session.commit()
path_to_asset = {
"/path1.bin": asset1.id,
"/path2.bin": asset2.id,
}
winners = get_cache_states_by_paths_and_asset_ids(session, path_to_asset)
assert winners == {"/path1.bin", "/path2.bin"}
def test_excludes_non_matching_asset_ids(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_cache_state(session, asset1, "/path1.bin")
session.commit()
# Path exists but with different asset_id
path_to_asset = {"/path1.bin": asset2.id}
winners = get_cache_states_by_paths_and_asset_ids(session, path_to_asset)
assert winners == set()
def test_empty_dict_returns_empty(self, session: Session):
winners = get_cache_states_by_paths_and_asset_ids(session, {})
assert winners == set()

View File

@@ -0,0 +1,184 @@
"""Tests for metadata filtering logic in asset_info queries."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta
from app.assets.database.queries import list_asset_infos_page
from app.assets.database.queries.asset_info import convert_metadata_to_rows
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str) -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_asset_info(
session: Session,
asset: Asset,
name: str,
metadata: dict | None = None,
) -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id="",
name=name,
asset_id=asset.id,
user_metadata=metadata,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
if metadata:
for key, val in metadata.items():
for row in convert_metadata_to_rows(key, val):
meta_row = AssetInfoMeta(
asset_info_id=info.id,
key=row["key"],
ordinal=row.get("ordinal", 0),
val_str=row.get("val_str"),
val_num=row.get("val_num"),
val_bool=row.get("val_bool"),
val_json=row.get("val_json"),
)
session.add(meta_row)
session.flush()
return info
class TestMetadataFilterByType:
"""Table-driven tests for metadata filtering by different value types."""
@pytest.mark.parametrize(
"match_meta,nomatch_meta,filter_key,filter_val",
[
# String matching
({"category": "models"}, {"category": "images"}, "category", "models"),
# Integer matching
({"epoch": 5}, {"epoch": 10}, "epoch", 5),
# Float matching
({"score": 0.95}, {"score": 0.5}, "score", 0.95),
# Boolean True matching
({"enabled": True}, {"enabled": False}, "enabled", True),
# Boolean False matching
({"enabled": False}, {"enabled": True}, "enabled", False),
],
ids=["string", "int", "float", "bool_true", "bool_false"],
)
def test_filter_matches_correct_value(
self, session: Session, match_meta, nomatch_meta, filter_key, filter_val
):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "match", match_meta)
_make_asset_info(session, asset, "nomatch", nomatch_meta)
session.commit()
infos, _, total = list_asset_infos_page(
session, metadata_filter={filter_key: filter_val}
)
assert total == 1
assert infos[0].name == "match"
@pytest.mark.parametrize(
"stored_meta,filter_key,filter_val",
[
# String no match
({"category": "models"}, "category", "other"),
# Int no match
({"epoch": 5}, "epoch", 99),
# Float no match
({"score": 0.5}, "score", 0.99),
],
ids=["string_no_match", "int_no_match", "float_no_match"],
)
def test_filter_returns_empty_when_no_match(
self, session: Session, stored_meta, filter_key, filter_val
):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "item", stored_meta)
session.commit()
infos, _, total = list_asset_infos_page(
session, metadata_filter={filter_key: filter_val}
)
assert total == 0
class TestMetadataFilterNull:
"""Tests for null/missing key filtering."""
@pytest.mark.parametrize(
"match_name,match_meta,nomatch_name,nomatch_meta,filter_key",
[
# Null matches missing key
("missing_key", {}, "has_key", {"optional": "value"}, "optional"),
# Null matches explicit null
("explicit_null", {"nullable": None}, "has_value", {"nullable": "present"}, "nullable"),
],
ids=["missing_key", "explicit_null"],
)
def test_null_filter_matches(
self, session: Session, match_name, match_meta, nomatch_name, nomatch_meta, filter_key
):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, match_name, match_meta)
_make_asset_info(session, asset, nomatch_name, nomatch_meta)
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={filter_key: None})
assert total == 1
assert infos[0].name == match_name
class TestMetadataFilterList:
"""Tests for list-based (OR) filtering."""
def test_filter_by_list_matches_any(self, session: Session):
"""List values should match ANY of the values (OR)."""
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "cat_a", {"category": "a"})
_make_asset_info(session, asset, "cat_b", {"category": "b"})
_make_asset_info(session, asset, "cat_c", {"category": "c"})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={"category": ["a", "b"]})
assert total == 2
names = {i.name for i in infos}
assert names == {"cat_a", "cat_b"}
class TestMetadataFilterMultipleKeys:
"""Tests for multiple filter keys (AND semantics)."""
def test_multiple_keys_must_all_match(self, session: Session):
"""Multiple keys should ALL match (AND)."""
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "match", {"type": "model", "version": 2})
_make_asset_info(session, asset, "wrong_type", {"type": "config", "version": 2})
_make_asset_info(session, asset, "wrong_version", {"type": "model", "version": 1})
session.commit()
infos, _, total = list_asset_infos_page(
session, metadata_filter={"type": "model", "version": 2}
)
assert total == 1
assert infos[0].name == "match"
class TestMetadataFilterEmptyDict:
"""Tests for empty filter behavior."""
def test_empty_filter_returns_all(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "a", {"key": "val"})
_make_asset_info(session, asset, "b", {})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={})
assert total == 2

View File

@@ -0,0 +1,366 @@
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo, AssetInfoTag, AssetInfoMeta, Tag
from app.assets.database.queries import (
ensure_tags_exist,
get_asset_tags,
set_asset_info_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
add_missing_tag_for_asset_id,
remove_missing_tag_for_asset_id,
list_tags_with_usage,
bulk_insert_tags_and_meta,
)
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str | None = None) -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_asset_info(session: Session, asset: Asset, name: str = "test", owner_id: str = "") -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
class TestEnsureTagsExist:
def test_creates_new_tags(self, session: Session):
ensure_tags_exist(session, ["alpha", "beta"], tag_type="user")
session.commit()
tags = session.query(Tag).all()
assert {t.name for t in tags} == {"alpha", "beta"}
def test_is_idempotent(self, session: Session):
ensure_tags_exist(session, ["alpha"], tag_type="user")
ensure_tags_exist(session, ["alpha"], tag_type="user")
session.commit()
assert session.query(Tag).count() == 1
def test_normalizes_tags(self, session: Session):
ensure_tags_exist(session, [" ALPHA ", "Beta", "alpha"])
session.commit()
tags = session.query(Tag).all()
assert {t.name for t in tags} == {"alpha", "beta"}
def test_empty_list_is_noop(self, session: Session):
ensure_tags_exist(session, [])
session.commit()
assert session.query(Tag).count() == 0
def test_tag_type_is_set(self, session: Session):
ensure_tags_exist(session, ["system-tag"], tag_type="system")
session.commit()
tag = session.query(Tag).filter_by(name="system-tag").one()
assert tag.tag_type == "system"
class TestGetAssetTags:
def test_returns_empty_for_no_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
tags = get_asset_tags(session, asset_info_id=info.id)
assert tags == []
def test_returns_tags_for_asset(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["tag1", "tag2"])
session.add_all([
AssetInfoTag(asset_info_id=info.id, tag_name="tag1", origin="manual", added_at=get_utc_now()),
AssetInfoTag(asset_info_id=info.id, tag_name="tag2", origin="manual", added_at=get_utc_now()),
])
session.flush()
tags = get_asset_tags(session, asset_info_id=info.id)
assert set(tags) == {"tag1", "tag2"}
class TestSetAssetInfoTags:
def test_adds_new_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b"])
session.commit()
assert set(result["added"]) == {"a", "b"}
assert result["removed"] == []
assert set(result["total"]) == {"a", "b"}
def test_removes_old_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b", "c"])
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["a"])
session.commit()
assert result["added"] == []
assert set(result["removed"]) == {"b", "c"}
assert result["total"] == ["a"]
def test_replaces_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b"])
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["b", "c"])
session.commit()
assert result["added"] == ["c"]
assert result["removed"] == ["a"]
assert set(result["total"]) == {"b", "c"}
class TestAddTagsToAssetInfo:
def test_adds_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
result = add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x", "y"])
session.commit()
assert set(result["added"]) == {"x", "y"}
assert result["already_present"] == []
def test_reports_already_present(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x"])
result = add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x", "y"])
session.commit()
assert result["added"] == ["y"]
assert result["already_present"] == ["x"]
def test_raises_for_missing_asset_info(self, session: Session):
with pytest.raises(ValueError, match="not found"):
add_tags_to_asset_info(session, asset_info_id="nonexistent", tags=["x"])
class TestRemoveTagsFromAssetInfo:
def test_removes_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a", "b", "c"])
result = remove_tags_from_asset_info(session, asset_info_id=info.id, tags=["a", "b"])
session.commit()
assert set(result["removed"]) == {"a", "b"}
assert result["not_present"] == []
assert result["total_tags"] == ["c"]
def test_reports_not_present(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a"])
result = remove_tags_from_asset_info(session, asset_info_id=info.id, tags=["a", "x"])
session.commit()
assert result["removed"] == ["a"]
assert result["not_present"] == ["x"]
def test_raises_for_missing_asset_info(self, session: Session):
with pytest.raises(ValueError, match="not found"):
remove_tags_from_asset_info(session, asset_info_id="nonexistent", tags=["x"])
class TestMissingTagFunctions:
def test_add_missing_tag_for_asset_id(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
tags = get_asset_tags(session, asset_info_id=info.id)
assert "missing" in tags
def test_add_missing_tag_is_idempotent(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
add_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
links = session.query(AssetInfoTag).filter_by(asset_info_id=info.id, tag_name="missing").all()
assert len(links) == 1
def test_remove_missing_tag_for_asset_id(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
tags = get_asset_tags(session, asset_info_id=info.id)
assert "missing" not in tags
class TestListTagsWithUsage:
def test_returns_tags_with_counts(self, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
session.commit()
rows, total = list_tags_with_usage(session)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict["used"] == 1
assert tag_dict["unused"] == 0
assert total == 2
def test_exclude_zero_counts(self, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
session.commit()
rows, total = list_tags_with_usage(session, include_zero=False)
tag_names = {name for name, _, _ in rows}
assert "used" in tag_names
assert "unused" not in tag_names
def test_prefix_filter(self, session: Session):
ensure_tags_exist(session, ["alpha", "beta", "alphabet"])
session.commit()
rows, total = list_tags_with_usage(session, prefix="alph")
tag_names = {name for name, _, _ in rows}
assert tag_names == {"alpha", "alphabet"}
def test_order_by_name(self, session: Session):
ensure_tags_exist(session, ["zebra", "alpha", "middle"])
session.commit()
rows, _ = list_tags_with_usage(session, order="name_asc")
names = [name for name, _, _ in rows]
assert names == ["alpha", "middle", "zebra"]
def test_owner_visibility(self, session: Session):
ensure_tags_exist(session, ["shared-tag", "owner-tag"])
asset = _make_asset(session, "hash1")
shared_info = _make_asset_info(session, asset, name="shared", owner_id="")
owner_info = _make_asset_info(session, asset, name="owned", owner_id="user1")
add_tags_to_asset_info(session, asset_info_id=shared_info.id, tags=["shared-tag"])
add_tags_to_asset_info(session, asset_info_id=owner_info.id, tags=["owner-tag"])
session.commit()
# Empty owner sees only shared
rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict.get("shared-tag", 0) == 1
assert tag_dict.get("owner-tag", 0) == 0
# User1 sees both
rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict.get("shared-tag", 0) == 1
assert tag_dict.get("owner-tag", 0) == 1
class TestBulkInsertTagsAndMeta:
def test_inserts_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["bulk-tag1", "bulk-tag2"])
session.commit()
now = get_utc_now()
tag_rows = [
{"asset_info_id": info.id, "tag_name": "bulk-tag1", "origin": "manual", "added_at": now},
{"asset_info_id": info.id, "tag_name": "bulk-tag2", "origin": "manual", "added_at": now},
]
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[])
session.commit()
tags = get_asset_tags(session, asset_info_id=info.id)
assert set(tags) == {"bulk-tag1", "bulk-tag2"}
def test_inserts_meta(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
meta_rows = [
{
"asset_info_id": info.id,
"key": "meta-key",
"ordinal": 0,
"val_str": "meta-value",
"val_num": None,
"val_bool": None,
"val_json": None,
},
]
bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=meta_rows)
session.commit()
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
assert len(meta) == 1
assert meta[0].key == "meta-key"
assert meta[0].val_str == "meta-value"
def test_ignores_conflicts(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["existing-tag"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["existing-tag"])
session.commit()
now = get_utc_now()
tag_rows = [
{"asset_info_id": info.id, "tag_name": "existing-tag", "origin": "duplicate", "added_at": now},
]
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[])
session.commit()
# Should still have only one tag link
links = session.query(AssetInfoTag).filter_by(asset_info_id=info.id, tag_name="existing-tag").all()
assert len(links) == 1
# Origin should be original, not overwritten
assert links[0].origin == "manual"
def test_empty_lists_is_noop(self, session: Session):
bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=[])
assert session.query(AssetInfoTag).count() == 0
assert session.query(AssetInfoMeta).count() == 0

View File

@@ -0,0 +1 @@
# Service layer tests

View File

@@ -0,0 +1,48 @@
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from app.assets.database.models import Base
@pytest.fixture
def db_engine():
"""In-memory SQLite engine for fast unit tests."""
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
return engine
@pytest.fixture
def session(db_engine):
"""Session fixture for tests that need direct DB access."""
with Session(db_engine) as sess:
yield sess
@pytest.fixture
def mock_create_session(db_engine):
"""Patch create_session to use our in-memory database."""
from contextlib import contextmanager
from sqlalchemy.orm import Session as SASession
@contextmanager
def _create_session():
with SASession(db_engine) as sess:
yield sess
with patch("app.assets.services.ingest.create_session", _create_session), \
patch("app.assets.services.asset_management.create_session", _create_session), \
patch("app.assets.services.tagging.create_session", _create_session):
yield _create_session
@pytest.fixture
def temp_dir():
"""Temporary directory for file operations."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)

View File

@@ -0,0 +1,264 @@
"""Tests for asset_management services."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo
from app.assets.database.queries import ensure_tags_exist, add_tags_to_asset_info
from app.assets.helpers import get_utc_now
from app.assets.services import (
get_asset_detail,
update_asset_metadata,
delete_asset_reference,
set_asset_preview,
)
def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream")
session.add(asset)
session.flush()
return asset
def _make_asset_info(
session: Session,
asset: Asset,
name: str = "test",
owner_id: str = "",
) -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
class TestGetAssetDetail:
def test_returns_none_for_nonexistent(self, mock_create_session):
result = get_asset_detail(asset_info_id="nonexistent")
assert result is None
def test_returns_asset_with_tags(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, name="test.bin")
ensure_tags_exist(session, ["alpha", "beta"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["alpha", "beta"])
session.commit()
result = get_asset_detail(asset_info_id=info.id)
assert result is not None
assert result.info.id == info.id
assert result.asset.hash == asset.hash
assert set(result.tags) == {"alpha", "beta"}
def test_respects_owner_visibility(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
# Wrong owner cannot see
result = get_asset_detail(asset_info_id=info.id, owner_id="user2")
assert result is None
# Correct owner can see
result = get_asset_detail(asset_info_id=info.id, owner_id="user1")
assert result is not None
class TestUpdateAssetMetadata:
def test_updates_name(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, name="old_name.bin")
info_id = info.id
session.commit()
update_asset_metadata(
asset_info_id=info_id,
name="new_name.bin",
)
# Verify by re-fetching from DB
session.expire_all()
updated_info = session.get(AssetInfo, info_id)
assert updated_info.name == "new_name.bin"
def test_updates_tags(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["old"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["old"])
session.commit()
result = update_asset_metadata(
asset_info_id=info.id,
tags=["new1", "new2"],
)
assert set(result.tags) == {"new1", "new2"}
assert "old" not in result.tags
def test_updates_user_metadata(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
info_id = info.id
session.commit()
update_asset_metadata(
asset_info_id=info_id,
user_metadata={"key": "value", "num": 42},
)
# Verify by re-fetching from DB
session.expire_all()
updated_info = session.get(AssetInfo, info_id)
assert updated_info.user_metadata["key"] == "value"
assert updated_info.user_metadata["num"] == 42
def test_raises_for_nonexistent(self, mock_create_session):
with pytest.raises(ValueError, match="not found"):
update_asset_metadata(asset_info_id="nonexistent", name="fail")
def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
with pytest.raises(PermissionError, match="not owner"):
update_asset_metadata(
asset_info_id=info.id,
name="new",
owner_id="user2",
)
class TestDeleteAssetReference:
def test_deletes_asset_info(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
info_id = info.id
session.commit()
result = delete_asset_reference(
asset_info_id=info_id,
owner_id="",
delete_content_if_orphan=False,
)
assert result is True
assert session.get(AssetInfo, info_id) is None
def test_returns_false_for_nonexistent(self, mock_create_session):
result = delete_asset_reference(
asset_info_id="nonexistent",
owner_id="",
)
assert result is False
def test_returns_false_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
info_id = info.id
session.commit()
result = delete_asset_reference(
asset_info_id=info_id,
owner_id="user2",
)
assert result is False
assert session.get(AssetInfo, info_id) is not None
def test_keeps_asset_if_other_infos_exist(self, mock_create_session, session: Session):
asset = _make_asset(session)
info1 = _make_asset_info(session, asset, name="info1")
_make_asset_info(session, asset, name="info2") # Second info keeps asset alive
asset_id = asset.id
session.commit()
delete_asset_reference(
asset_info_id=info1.id,
owner_id="",
delete_content_if_orphan=True,
)
# Asset should still exist
assert session.get(Asset, asset_id) is not None
def test_deletes_orphaned_asset(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
asset_id = asset.id
info_id = info.id
session.commit()
delete_asset_reference(
asset_info_id=info_id,
owner_id="",
delete_content_if_orphan=True,
)
# Both info and asset should be gone
assert session.get(AssetInfo, info_id) is None
assert session.get(Asset, asset_id) is None
class TestSetAssetPreview:
def test_sets_preview(self, mock_create_session, session: Session):
asset = _make_asset(session, hash_val="blake3:main")
preview_asset = _make_asset(session, hash_val="blake3:preview")
info = _make_asset_info(session, asset)
info_id = info.id
preview_id = preview_asset.id
session.commit()
set_asset_preview(
asset_info_id=info_id,
preview_asset_id=preview_id,
)
# Verify by re-fetching from DB
session.expire_all()
updated_info = session.get(AssetInfo, info_id)
assert updated_info.preview_id == preview_id
def test_clears_preview(self, mock_create_session, session: Session):
asset = _make_asset(session)
preview_asset = _make_asset(session, hash_val="blake3:preview")
info = _make_asset_info(session, asset)
info.preview_id = preview_asset.id
info_id = info.id
session.commit()
set_asset_preview(
asset_info_id=info_id,
preview_asset_id=None,
)
# Verify by re-fetching from DB
session.expire_all()
updated_info = session.get(AssetInfo, info_id)
assert updated_info.preview_id is None
def test_raises_for_nonexistent_info(self, mock_create_session):
with pytest.raises(ValueError, match="not found"):
set_asset_preview(asset_info_id="nonexistent")
def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
with pytest.raises(PermissionError, match="not owner"):
set_asset_preview(
asset_info_id=info.id,
preview_asset_id=None,
owner_id="user2",
)

View File

@@ -0,0 +1,227 @@
"""Tests for ingest services."""
from pathlib import Path
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, Tag
from app.assets.database.queries import get_asset_tags
from app.assets.services import ingest_file_from_path, register_existing_asset
class TestIngestFileFromPath:
def test_creates_asset_and_cache_state(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "test_file.bin"
file_path.write_bytes(b"test content")
result = ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:abc123",
size_bytes=12,
mtime_ns=1234567890000000000,
mime_type="application/octet-stream",
)
assert result.asset_created is True
assert result.state_created is True
assert result.asset_info_id is None # no info_name provided
# Verify DB state
assets = session.query(Asset).all()
assert len(assets) == 1
assert assets[0].hash == "blake3:abc123"
states = session.query(AssetCacheState).all()
assert len(states) == 1
assert states[0].file_path == str(file_path)
def test_creates_asset_info_when_name_provided(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "model.safetensors"
file_path.write_bytes(b"model data")
result = ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:def456",
size_bytes=10,
mtime_ns=1234567890000000000,
mime_type="application/octet-stream",
info_name="My Model",
owner_id="user1",
)
assert result.asset_created is True
assert result.asset_info_id is not None
info = session.query(AssetInfo).first()
assert info is not None
assert info.name == "My Model"
assert info.owner_id == "user1"
def test_creates_tags_when_provided(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "tagged.bin"
file_path.write_bytes(b"data")
result = ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:ghi789",
size_bytes=4,
mtime_ns=1234567890000000000,
info_name="Tagged Asset",
tags=["models", "checkpoints"],
)
assert result.asset_info_id is not None
# Verify tags were created and linked
tags = session.query(Tag).all()
tag_names = {t.name for t in tags}
assert "models" in tag_names
assert "checkpoints" in tag_names
asset_tags = get_asset_tags(session, asset_info_id=result.asset_info_id)
assert set(asset_tags) == {"models", "checkpoints"}
def test_idempotent_upsert(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "dup.bin"
file_path.write_bytes(b"content")
# First ingest
r1 = ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:repeat",
size_bytes=7,
mtime_ns=1234567890000000000,
)
assert r1.asset_created is True
# Second ingest with same hash - should update, not create
r2 = ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:repeat",
size_bytes=7,
mtime_ns=1234567890000000001, # different mtime
)
assert r2.asset_created is False
assert r2.state_updated is True or r2.state_created is False
# Still only one asset
assets = session.query(Asset).all()
assert len(assets) == 1
def test_validates_preview_id(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "with_preview.bin"
file_path.write_bytes(b"data")
# Create a preview asset first
preview_asset = Asset(hash="blake3:preview", size_bytes=100)
session.add(preview_asset)
session.commit()
preview_id = preview_asset.id
result = ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:main",
size_bytes=4,
mtime_ns=1234567890000000000,
info_name="With Preview",
preview_id=preview_id,
)
assert result.asset_info_id is not None
info = session.query(AssetInfo).filter_by(id=result.asset_info_id).first()
assert info.preview_id == preview_id
def test_invalid_preview_id_is_cleared(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "bad_preview.bin"
file_path.write_bytes(b"data")
result = ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:badpreview",
size_bytes=4,
mtime_ns=1234567890000000000,
info_name="Bad Preview",
preview_id="nonexistent-uuid",
)
assert result.asset_info_id is not None
info = session.query(AssetInfo).filter_by(id=result.asset_info_id).first()
assert info.preview_id is None
class TestRegisterExistingAsset:
def test_creates_info_for_existing_asset(self, mock_create_session, session: Session):
# Create existing asset
asset = Asset(hash="blake3:existing", size_bytes=1024, mime_type="image/png")
session.add(asset)
session.commit()
result = register_existing_asset(
asset_hash="blake3:existing",
name="Registered Asset",
user_metadata={"key": "value"},
tags=["models"],
)
assert result.created is True
assert "models" in result.tags
# Verify by re-fetching from DB
session.expire_all()
infos = session.query(AssetInfo).filter_by(name="Registered Asset").all()
assert len(infos) == 1
def test_returns_existing_info(self, mock_create_session, session: Session):
# Create asset and info
asset = Asset(hash="blake3:withinfo", size_bytes=512)
session.add(asset)
session.flush()
from app.assets.helpers import get_utc_now
info = AssetInfo(
owner_id="",
name="Existing Info",
asset_id=asset.id,
created_at=get_utc_now(),
updated_at=get_utc_now(),
last_access_time=get_utc_now(),
)
session.add(info)
session.flush() # Flush to get the ID
info_id = info.id
session.commit()
result = register_existing_asset(
asset_hash="blake3:withinfo",
name="Existing Info",
owner_id="",
)
assert result.created is False
# Verify only one AssetInfo exists for this name
session.expire_all()
infos = session.query(AssetInfo).filter_by(name="Existing Info").all()
assert len(infos) == 1
assert infos[0].id == info_id
def test_raises_for_nonexistent_hash(self, mock_create_session):
with pytest.raises(ValueError, match="No asset with hash"):
register_existing_asset(
asset_hash="blake3:doesnotexist",
name="Fail",
)
def test_applies_tags_to_new_info(self, mock_create_session, session: Session):
asset = Asset(hash="blake3:tagged", size_bytes=256)
session.add(asset)
session.commit()
result = register_existing_asset(
asset_hash="blake3:tagged",
name="Tagged Info",
tags=["alpha", "beta"],
)
assert result.created is True
assert set(result.tags) == {"alpha", "beta"}

View File

@@ -0,0 +1,197 @@
"""Tests for tagging services."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo
from app.assets.database.queries import ensure_tags_exist, add_tags_to_asset_info
from app.assets.helpers import get_utc_now
from app.assets.services import apply_tags, remove_tags, list_tags
def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_asset_info(
session: Session,
asset: Asset,
name: str = "test",
owner_id: str = "",
) -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
class TestApplyTags:
def test_adds_new_tags(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
session.commit()
result = apply_tags(
asset_info_id=info.id,
tags=["alpha", "beta"],
)
assert set(result.added) == {"alpha", "beta"}
assert result.already_present == []
assert set(result.total_tags) == {"alpha", "beta"}
def test_reports_already_present(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["existing"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["existing"])
session.commit()
result = apply_tags(
asset_info_id=info.id,
tags=["existing", "new"],
)
assert result.added == ["new"]
assert result.already_present == ["existing"]
def test_raises_for_nonexistent_info(self, mock_create_session):
with pytest.raises(ValueError, match="not found"):
apply_tags(asset_info_id="nonexistent", tags=["x"])
def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
with pytest.raises(PermissionError, match="not owner"):
apply_tags(
asset_info_id=info.id,
tags=["new"],
owner_id="user2",
)
class TestRemoveTags:
def test_removes_tags(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["a", "b", "c"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a", "b", "c"])
session.commit()
result = remove_tags(
asset_info_id=info.id,
tags=["a", "b"],
)
assert set(result.removed) == {"a", "b"}
assert result.not_present == []
assert result.total_tags == ["c"]
def test_reports_not_present(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["present"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["present"])
session.commit()
result = remove_tags(
asset_info_id=info.id,
tags=["present", "absent"],
)
assert result.removed == ["present"]
assert result.not_present == ["absent"]
def test_raises_for_nonexistent_info(self, mock_create_session):
with pytest.raises(ValueError, match="not found"):
remove_tags(asset_info_id="nonexistent", tags=["x"])
def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
with pytest.raises(PermissionError, match="not owner"):
remove_tags(
asset_info_id=info.id,
tags=["x"],
owner_id="user2",
)
class TestListTags:
def test_returns_tags_with_counts(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session)
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
session.commit()
rows, total = list_tags()
tag_dict = {name: count for name, _, count in rows}
assert tag_dict["used"] == 1
assert tag_dict["unused"] == 0
assert total == 2
def test_excludes_zero_counts(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session)
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
session.commit()
rows, total = list_tags(include_zero=False)
tag_names = {name for name, _, _ in rows}
assert "used" in tag_names
assert "unused" not in tag_names
def test_prefix_filter(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["alpha", "beta", "alphabet"])
session.commit()
rows, _ = list_tags(prefix="alph")
tag_names = {name for name, _, _ in rows}
assert tag_names == {"alpha", "alphabet"}
def test_order_by_name(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["zebra", "alpha", "middle"])
session.commit()
rows, _ = list_tags(order="name_asc")
names = [name for name, _, _ in rows]
assert names == ["alpha", "middle", "zebra"]
def test_pagination(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["a", "b", "c", "d", "e"])
session.commit()
rows, total = list_tags(limit=2, offset=1, order="name_asc")
assert total == 5
assert len(rows) == 2
names = [name for name, _, _ in rows]
assert names == ["b", "c"]
def test_clamps_limit(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["a"])
session.commit()
# Service should clamp limit to max 1000
rows, _ = list_tags(limit=2000)
assert len(rows) <= 1000

View File

@@ -4,7 +4,7 @@ from pathlib import Path
import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
from helpers import get_asset_filename, trigger_sync_seed_assets

View File

@@ -4,7 +4,7 @@ from pathlib import Path
import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
from helpers import get_asset_filename, trigger_sync_seed_assets
def test_create_from_hash_success(
@@ -126,42 +126,52 @@ def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api
assert body == b""
def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str):
bogus = str(uuid.uuid4())
r = http.delete(f"{api_base}/api/assets/{bogus}", timeout=120)
@pytest.mark.parametrize(
"method,endpoint_template,payload,expected_status,error_code",
[
# Delete nonexistent asset
("delete", "/api/assets/{uuid}", None, 404, "ASSET_NOT_FOUND"),
# Bad hash algorithm in from-hash
(
"post",
"/api/assets/from-hash",
{"hash": "sha256:" + "0" * 64, "name": "x.bin", "tags": ["models", "checkpoints", "unit-tests"]},
400,
"INVALID_BODY",
),
# Get with bad UUID format
("get", "/api/assets/not-a-uuid", None, 404, None),
# Get content with bad UUID format
("get", "/api/assets/not-a-uuid/content", None, 404, None),
],
ids=["delete_nonexistent", "bad_hash_algorithm", "get_bad_uuid", "content_bad_uuid"],
)
def test_error_responses(
http: requests.Session, api_base: str, method, endpoint_template, payload, expected_status, error_code
):
# Replace {uuid} placeholder with a random UUID for delete test
endpoint = endpoint_template.replace("{uuid}", str(uuid.uuid4()))
url = f"{api_base}{endpoint}"
if method == "get":
r = http.get(url, timeout=120)
elif method == "post":
r = http.post(url, json=payload, timeout=120)
elif method == "delete":
r = http.delete(url, timeout=120)
assert r.status_code == expected_status
if error_code:
body = r.json()
assert body["error"]["code"] == error_code
def test_create_from_hash_invalid_json(http: requests.Session, api_base: str):
"""Invalid JSON body requires special handling (data= instead of json=)."""
r = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120)
body = r.json()
assert r.status_code == 404
assert body["error"]["code"] == "ASSET_NOT_FOUND"
def test_create_from_hash_invalids(http: requests.Session, api_base: str):
# Bad hash algorithm
bad = {
"hash": "sha256:" + "0" * 64,
"name": "x.bin",
"tags": ["models", "checkpoints", "unit-tests"],
}
r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120)
b1 = r1.json()
assert r1.status_code == 400
assert b1["error"]["code"] == "INVALID_BODY"
# Invalid JSON body
r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120)
b2 = r2.json()
assert r2.status_code == 400
assert b2["error"]["code"] == "INVALID_JSON"
def test_get_update_download_bad_ids(http: requests.Session, api_base: str):
# All endpoints should be not found, as we UUID regex directly in the route definition.
bad_id = "not-a-uuid"
r1 = http.get(f"{api_base}/api/assets/{bad_id}", timeout=120)
assert r1.status_code == 404
r3 = http.get(f"{api_base}/api/assets/{bad_id}/content", timeout=120)
assert r3.status_code == 404
assert r.status_code == 400
assert body["error"]["code"] == "INVALID_JSON"
def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict):

View File

@@ -6,7 +6,7 @@ from typing import Optional
import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
from helpers import get_asset_filename, trigger_sync_seed_assets
def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict):

View File

@@ -1,6 +1,7 @@
import time
import uuid
import pytest
import requests
@@ -283,30 +284,21 @@ def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asse
assert b2["has_more"] is False
def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base):
r1 = http.get(api_base + "/api/assets", params={"offset": "-1"}, timeout=120)
b1 = r1.json()
assert r1.status_code == 400
assert b1["error"]["code"] == "INVALID_QUERY"
r2 = http.get(api_base + "/api/assets", params={"limit": "abc"}, timeout=120)
b2 = r2.json()
assert r2.status_code == 400
assert b2["error"]["code"] == "INVALID_QUERY"
def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str):
# limit too small
r1 = http.get(api_base + "/api/assets", params={"limit": "0"}, timeout=120)
b1 = r1.json()
assert r1.status_code == 400
assert b1["error"]["code"] == "INVALID_QUERY"
# bad metadata JSON
r2 = http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}, timeout=120)
b2 = r2.json()
assert r2.status_code == 400
assert b2["error"]["code"] == "INVALID_QUERY"
@pytest.mark.parametrize(
"params,error_code",
[
({"offset": "-1"}, "INVALID_QUERY"),
({"limit": "abc"}, "INVALID_QUERY"),
({"limit": "0"}, "INVALID_QUERY"),
({"metadata_filter": "{not json"}, "INVALID_QUERY"),
],
ids=["negative_offset", "non_int_limit", "zero_limit", "invalid_metadata_json"],
)
def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str, params, error_code):
r = http.get(api_base + "/api/assets", params=params, timeout=120)
body = r.json()
assert r.status_code == 400
assert body["error"]["code"] == error_code
def test_list_assets_name_contains_literal_underscore(

View File

@@ -3,7 +3,7 @@ from pathlib import Path
import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
from helpers import get_asset_filename, trigger_sync_seed_assets
@pytest.fixture